diff --git a/NEWS.md b/NEWS.md index 342a46dd..e269aac5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,7 @@ # mlr3learners 0.5.6-9000 * Added formula argument to `nnet` learner and support feature type `"integer"` +* Added `min.bucket` parameter to `classif.ranger` and `regr.ranger`. # mlr3learners 0.5.6 diff --git a/R/LearnerClassifRanger.R b/R/LearnerClassifRanger.R index 945dc402..c4c06a56 100644 --- a/R/LearnerClassifRanger.R +++ b/R/LearnerClassifRanger.R @@ -44,6 +44,7 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger", importance = p_fct(c("none", "impurity", "impurity_corrected", "permutation"), tags = "train"), keep.inbag = p_lgl(default = FALSE, tags = "train"), max.depth = p_int(default = NULL, lower = 0L, special_vals = list(NULL), tags = "train"), + min.bucket = p_int(1L, default = 1L, tags = "train"), min.node.size = p_int(1L, default = NULL, special_vals = list(NULL), tags = "train"), min.prop = p_dbl(default = 0.1, tags = "train"), minprop = p_dbl(default = 0.1, tags = "train"), @@ -131,8 +132,10 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger", newdata = ordered_features(task, self) prediction = invoke(predict, - self$model, data = newdata, - predict.type = "response", .args = pv) + self$model, + data = newdata, + predict.type = "response", .args = pv + ) if (self$predict_type == "response") { list(response = prediction$predictions) diff --git a/R/LearnerRegrRanger.R b/R/LearnerRegrRanger.R index 590ca4ba..6276429d 100644 --- a/R/LearnerRegrRanger.R +++ b/R/LearnerRegrRanger.R @@ -33,6 +33,7 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger", importance = p_fct(c("none", "impurity", "impurity_corrected", "permutation"), tags = "train"), keep.inbag = p_lgl(default = FALSE, tags = "train"), max.depth = p_int(default = NULL, lower = 0L, special_vals = list(NULL), tags = "train"), + min.bucket = p_int(1L, default = 1L, tags = "train"), min.node.size = p_int(1L, default = 5L, special_vals = list(NULL), tags = "train"), min.prop = p_dbl(default = 0.1, tags = "train"), minprop = p_dbl(default = 0.1, tags = "train"),