22# ' Check input parameters
33# '
44# ' @description
5- # ' Checks consistency in input (hyper) parameters for the cre function.
5+ # ' Checks consistency in input (hyper) parameters for the ` cre` function.
66# '
77# ' @param X_names The observed covariates names.
88# ' @param params The list of parameters required to run the function.
@@ -18,30 +18,18 @@ check_hyper_params <- function(X_names, params) {
1818 logger :: log_debug(" Checking hyper parameters..." )
1919
2020 # Input params checks --------------------------------------------------------
21- ntrees_rf <- getElement(params , " ntrees_rf " )
22- if (length(ntrees_rf ) == 0 ) {
23- ntrees_rf <- 20
21+ ntrees <- getElement(params , " ntrees " )
22+ if (length(ntrees ) == 0 ) {
23+ ntrees <- 20
2424 } else {
25- if (! inherits(ntrees_rf , " numeric" )) {
26- stop(" Invalid 'ntrees_rf ' input. Please input a number. " )
25+ if (! inherits(ntrees , " numeric" )) {
26+ stop(" Invalid 'ntrees ' input. Please input a positive integer " )
2727 }
28- }
29- params [[" ntrees_rf" ]] <- ntrees_rf
30-
31- ntrees_gbm <- getElement(params , " ntrees_gbm" )
32- if (length(ntrees_gbm ) == 0 ) {
33- ntrees_gbm <- 20
34- } else {
35- if (! inherits(ntrees_gbm , " numeric" )) {
36- stop(" Invalid 'ntrees_gbm' input. Please input a number." )
28+ if (ntrees < 1 ) {
29+ stop(" Invalid 'ntrees' input. Please input a positive integer" )
3730 }
3831 }
39- params [[" ntrees_gbm" ]] <- ntrees_gbm
40-
41- if (params [[" ntrees_gbm" ]] + params [[" ntrees_rf" ]] == 0 ) {
42- stop(" The total number of trees (ntrees_rf + ntrees_gbm) has to be
43- greater than 0" )
44- }
32+ params [[" ntrees" ]] <- ntrees
4533
4634 node_size <- getElement(params , " node_size" )
4735 if (length(node_size ) == 0 ) {
@@ -53,15 +41,15 @@ check_hyper_params <- function(X_names, params) {
5341 }
5442 params [[" node_size" ]] <- node_size
5543
56- max_nodes <- getElement(params , " max_nodes " )
57- if (length(max_nodes ) == 0 ) {
58- max_nodes <- 5
44+ max_rules <- getElement(params , " max_rules " )
45+ if (length(max_rules ) == 0 ) {
46+ max_rules <- 50
5947 } else {
60- if (! inherits(max_nodes , " numeric" )) {
61- stop(" Invalid 'max_nodes ' input. Please input a number." )
48+ if (! inherits(max_rules , " numeric" )) {
49+ stop(" Invalid 'max_rules ' input. Please input a number." )
6250 }
6351 }
64- params [[" max_nodes " ]] <- max_nodes
52+ params [[" max_rules " ]] <- max_rules
6553
6654 max_depth <- getElement(params , " max_depth" )
6755 if (length(max_depth ) == 0 ) {
@@ -119,58 +107,36 @@ check_hyper_params <- function(X_names, params) {
119107 }
120108 params [[" t_corr" ]] <- t_corr
121109
122- t_pvalue <- getElement(params , " t_pvalue " )
123- if (length(t_pvalue ) == 0 ) {
124- t_pvalue <- 0.05
110+ stability_selection <- getElement(params , " stability_selection " )
111+ if (length(stability_selection ) == 0 ) {
112+ stability_selection <- " vanilla "
125113 } else {
126- if (! inherits(t_pvalue , " numeric" )) {
127- stop(" Invalid 't_pvalue' input. Please input a number." )
114+ if (! (stability_selection %in% c(" error_control" , " no" ," vanilla" ))) {
115+ stop(paste0(" Invalid `stability_selection` argument. Please input " ,
116+ " a value among: {`no`, `vanilla`, `error_control`}." ))
128117 }
129118 }
130- params [[" t_pvalue " ]] <- t_pvalue
119+ params [[" stability_selection " ]] <- stability_selection
131120
132- stability_selection <- getElement(params , " stability_selection" )
133- pfer <- getElement(params , " pfer" )
134121 cutoff <- getElement(params , " cutoff" )
135- if (length(stability_selection ) == 0 ) {
136- stability_selection <- TRUE
137- pfer <- 1
122+ if (length(cutoff ) == 0 ) {
138123 cutoff <- 0.9
139124 } else {
140- if (! (stability_selection %in% c(TRUE , FALSE ))) {
141- stop(paste0(" Please specify 'TRUE' or 'FALSE' for" ,
142- " the stability_selection argument." ))
143- } else if (stability_selection ) {
144- if (length(pfer ) == 0 ) {
145- pfer <- 1
146- } else {
147- if (! inherits(pfer , " numeric" )) {
148- stop(" Invalid 'pfer' input. Please input a number." )
149- }
150- }
151- if (length(cutoff ) == 0 ) {
152- cutoff <- 0.9
153- } else {
154- if (! inherits(cutoff , " numeric" )) {
155- stop(" Invalid 'cutoff' input. Please input a number." )
156- }
157- }
125+ if (! inherits(cutoff , " numeric" )) {
126+ stop(" Invalid 'cutoff' input. Please input a number." )
158127 }
159128 }
160- params [[" stability_selection" ]] <- stability_selection
161- params [[" pfer" ]] <- pfer
162129 params [[" cutoff" ]] <- cutoff
163130
164-
165- penalty_rl <- getElement(params , " penalty_rl" )
166- if (length(penalty_rl ) == 0 ) {
167- penalty_rl <- 1
131+ pfer <- getElement(params , " pfer" )
132+ if (length(pfer ) == 0 ) {
133+ pfer <- 1
168134 } else {
169- if (! inherits(penalty_rl , " numeric" )) {
170- stop(" Invalid 'penalty_rl ' input. Please input a number." )
135+ if (! inherits(pfer , " numeric" )) {
136+ stop(" Invalid 'pfer ' input. Please input a number." )
171137 }
172138 }
173- params [[" penalty_rl " ]] <- penalty_rl
139+ params [[" pfer " ]] <- pfer
174140
175141 intervention_vars <- getElement(params , " intervention_vars" )
176142 if (length(intervention_vars ) == 0 ) {
@@ -196,6 +162,28 @@ check_hyper_params <- function(X_names, params) {
196162 }
197163 params [[" offset" ]] <- offset
198164
165+ # Check for correct B input
166+ B <- getElement(params , " B" )
167+ if (length(B ) == 0 ) {
168+ B <- 20
169+ } else {
170+ if (! inherits(B , " numeric" )) {
171+ stop(" Invalid 'B' input. Please input an integer." )
172+ }
173+ }
174+ params [[" B" ]] <- B
175+
176+ # Check for correct subsample imput
177+ subsample <- getElement(params , " subsample" )
178+ if (length(subsample ) == 0 ) {
179+ subsample <- 0.5
180+ } else {
181+ if (! inherits(subsample , " numeric" ) || (subsample < 0 ) || (subsample > 1 )) {
182+ stop(" Invalid 'subsample' input. Please input a number between 0 and 1." )
183+ }
184+ }
185+ params [[" subsample" ]] <- subsample
186+
199187 logger :: log_debug(" Done with checking hyper parameters." )
200188
201189 return (params )
0 commit comments