Merge pull request #264 from mlr-org/nnet
feat: allow formula as argument for nnet learner
sebffischer authored Mar 24, 2023
2 parents 856d1d0 + efe7559 commit bedb629
Showing 7 changed files with 43 additions and 13 deletions.
6 changes: 5 additions & 1 deletion NEWS.md
@@ -1,4 +1,8 @@
-# mlr3learners 0.5.5
+# mlr3learners 0.5.6-9000
+
+* Added formula argument to `nnet` learner and support feature type `"integer"`
+
+# mlr3learners 0.5.6

 - Enable new early stopping mechanism for xgboost.
 - Improved documentation.
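
For illustration only, a minimal sketch of the `"integer"` feature-type support mentioned in that NEWS entry (the toy data, column names, and hyperparameter values are invented for this example, not taken from the commit):

    library(mlr3)
    library(mlr3learners)
    library(data.table)

    # Toy classification data with an integer column; "integer" is now listed in
    # the learner's feature_types, so such columns are no longer rejected by the
    # feature-type check at training time.
    dt = data.table(
      y     = factor(rep(c("a", "b"), each = 25)),
      x_int = sample(1:10, 50, replace = TRUE),
      x_num = rnorm(50)
    )
    task = as_task_classif(dt, target = "y")

    lrn("classif.nnet", size = 2L, trace = FALSE)$train(task)
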
14 changes: 10 additions & 4 deletions R/LearnerClassifNnet.R
@@ -17,6 +17,9 @@
 #' - Adjusted default: 3L.
 #' - Reason for change: no default in `nnet()`.
 #'
+#' @section Custom mlr3 parameters:
+#' - `formula`: if not provided, the formula is set to `task$formula()`.
+#'
 #' @references
 #' `r format_bib("ripley_1996")`
 #'
@@ -46,14 +49,15 @@ LearnerClassifNnet = R6Class("LearnerClassifNnet",
         size = p_int(0L, default = 3L, tags = "train"),
         skip = p_lgl(default = FALSE, tags = "train"),
         subset = p_uty(tags = "train"),
-        trace = p_lgl(default = TRUE, tags = "train")
+        trace = p_lgl(default = TRUE, tags = "train"),
+        formula = p_uty(tags = "train")
       )
       ps$values = list(size = 3L)

       super$initialize(
         id = "classif.nnet",
         packages = c("mlr3learners", "nnet"),
-        feature_types = c("numeric", "factor", "ordered"),
+        feature_types = c("numeric", "factor", "ordered", "integer"),
         predict_types = c("prob", "response"),
         param_set = ps,
         properties = c("twoclass", "multiclass", "weights"),
@@ -68,9 +72,11 @@ LearnerClassifNnet = R6Class("LearnerClassifNnet",
       if ("weights" %in% task$properties) {
         pv = insert_named(pv, list(weights = task$weights$weight))
       }
-      f = task$formula()
+      if (is.null(pv$formula)) {
+        pv$formula = task$formula()
+      }
       data = task$data()
-      invoke(nnet::nnet.formula, formula = f, data = data, .args = pv)
+      invoke(nnet::nnet.formula, data = data, .args = pv)
     },

     .predict = function(task) {
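
As a usage sketch of the new `formula` parameter on the classification side (the task, formula, and hyperparameter values below are illustrative choices, not part of the commit):

    library(mlr3)
    library(mlr3learners)

    task = tsk("iris")

    # Without `formula`, training falls back to task$formula(), i.e. Species ~ .
    lrn("classif.nnet", size = 5L, trace = FALSE)$train(task)

    # With `formula`, only the given terms are handed to nnet::nnet.formula().
    learner = lrn("classif.nnet",
      formula = Species ~ Petal.Length + Petal.Width,
      size = 5L, trace = FALSE
    )
    learner$train(task)
    learner$predict(task)$score(msr("classif.acc"))
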
14 changes: 10 additions & 4 deletions R/LearnerRegrNnet.R
@@ -17,6 +17,9 @@
 #' - Adjusted default: 3L.
 #' - Reason for change: no default in `nnet()`.
 #'
+#' @section Custom mlr3 parameters:
+#' - `formula`: if not provided, the formula is set to `task$formula()`.
+#'
 #' @references
 #' `r format_bib("ripley_1996")`
 #'
@@ -46,14 +49,15 @@ LearnerRegrNnet = R6Class("LearnerRegrNnet",
         size = p_int(0L, default = 3L, tags = "train"),
         skip = p_lgl(default = FALSE, tags = "train"),
         subset = p_uty(tags = "train"),
-        trace = p_lgl(default = TRUE, tags = "train")
+        trace = p_lgl(default = TRUE, tags = "train"),
+        formula = p_uty(tags = "train")
       )
       ps$values = list(size = 3L)

       super$initialize(
         id = "regr.nnet",
         packages = c("mlr3learners", "nnet"),
-        feature_types = c("numeric", "factor", "ordered"),
+        feature_types = c("numeric", "factor", "ordered", "integer"),
         predict_types = c("response"),
         param_set = ps,
         properties = c("weights"),
@@ -68,10 +72,12 @@ LearnerRegrNnet = R6Class("LearnerRegrNnet",
       if ("weights" %in% task$properties) {
         pv = insert_named(pv, list(weights = task$weights$weight))
       }
-      f = task$formula()
+      if (is.null(pv$formula)) {
+        pv$formula = task$formula()
+      }
       data = task$data()
       # force linout = TRUE for regression
-      invoke(nnet::nnet.formula, formula = f, data = data, linout = TRUE, .args = pv)
+      invoke(nnet::nnet.formula, data = data, linout = TRUE, .args = pv)
     },

     .predict = function(task) {
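
A corresponding sketch for the regression learner (task and formula again chosen only for illustration); note that `linout = TRUE` is still injected by the learner itself, as the hunk above shows:

    library(mlr3)
    library(mlr3learners)

    task = tsk("mtcars")
    learner = lrn("regr.nnet", size = 4L, trace = FALSE)

    # The parameter can also be set after construction; if it is left NULL,
    # the learner uses task$formula() (here mpg ~ .) instead.
    learner$param_set$values$formula = mpg ~ hp + wt
    learner$train(task)
    learner$predict(task)$score(msr("regr.rmse"))
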
1 change: 0 additions & 1 deletion inst/paramtest/test_paramtest_classif.nnet.R
@@ -7,7 +7,6 @@ test_that("classif.nnet", {
     "x", # handled via mlr3
     "y", # handled via mlr3
     "weights", # handled via mlr3
-    "formula", # handled via mlr3
     "data", # handled via mlr3
     "entropy", # automatically set to TRUE if two-class task
     "softmax", # automatically set to TRUE if multi-class task
1 change: 0 additions & 1 deletion inst/paramtest/test_paramtest_regr.nnet.R
@@ -7,7 +7,6 @@ test_that("regr.nnet", {
     "x", # handled via mlr3
     "y", # handled via mlr3
     "weights", # handled via mlr3
-    "formula", # handled via mlr3
     "data", # handled via mlr3
     "linout", # automatically set to TRUE, since it's the regression learner
     "entropy", # mutually exclusive with linout
10 changes: 9 additions & 1 deletion man/mlr_learners_classif.nnet.Rd

Some generated files are not rendered by default.

10 changes: 9 additions & 1 deletion man/mlr_learners_regr.nnet.Rd

Some generated files are not rendered by default.
