BART models give different predictions #1247

Open · balraadjsings opened this issue Jan 31, 2025 · 1 comment

@balraadjsings commented Jan 31, 2025

The problem

I'm having trouble using the final fitted model to generate predictions with the BART algorithm. Using the fitted workflow to predict on new data produces different predictions every time it is run. Other algorithms from parsnip give the same predictions when run multiple times. Is this intended for BART models? I can get stable predictions when set.seed() is used, but I'm not sure whether this is intended behavior or a bug.

The code below is an example from my original code, adapted to the iris dataset. I've run the code with and without parallel processing (in the current code parallel processing is turned off), and I get the same result either way: different predictions when run multiple times.

Reproducible example

## Bayesian additive regression trees (BART) ##

library(tidymodels)   # rsample, recipes, parsnip, workflows, tune, dials, yardstick
library(doParallel)   # parallel backend for tuning (also attaches parallel)

# Metric set used for tuning and the final fit
metrics <- metric_set(rmse, rsq)

data_analysis <- iris


# Split data into 60% training and 40% test set
set.seed(123)

data_split <- initial_split(data_analysis,
                            prop = 0.6)


data_train <- training(data_split)
data_test <- testing(data_split)

# 10-fold cross validation of training set
set.seed(345)
data_folds <- vfold_cv(data_train, v = 10, repeats = 1)


# Recipe

model_rec <- recipe(Sepal.Length ~ ., data = data_train) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_dummy(all_nominal_predictors(), one_hot = TRUE)



# Specify model and engine
bart_spec <- parsnip::bart(trees = tune(),
                           prior_terminal_node_coef = tune(),
                           prior_terminal_node_expo = tune(),
                           prior_outcome_range = tune()) %>%
  set_mode("regression") %>%
  set_engine("dbarts")


# Create workflow with pre-processing recipe and model specification
bart_workflow <- workflow() %>%
  add_recipe(model_rec) %>% 
  add_model(bart_spec) 




# Set hyperparameter ranges
param <- extract_parameter_set_dials(bart_workflow) %>%
  update(trees = trees(c(1000, 1500)),
         prior_terminal_node_coef = prior_terminal_node_coef(c(0.7, 1), NULL),
         prior_terminal_node_expo = prior_terminal_node_expo(c(1, 2)),
         prior_outcome_range = prior_outcome_range(c(0.25, 0.75), NULL))

# Parallel processing to speed up tuning (note: can cause R to crash;
# the tune_grid() call below runs sequentially because allow_par = FALSE)
ncores <- 4
cl <- makeCluster(ncores)
registerDoParallel(cl)

# Initial tuning
set.seed(456)
bart_tune <- tune_grid(bart_workflow,
                       resamples = data_folds,
                       grid = 10,
                       control = control_grid(save_pred = TRUE,
                                              allow_par = FALSE,
                                              save_workflow = TRUE),
                       metrics = metrics,
                       param_info = param)

stopCluster(cl) # Shut down cluster


bart_best_parms <- bart_tune %>% select_best(metric = "rmse") # Select best parameters based on RMSE

## Final model ##


# Update hyperparameters in recipe and model 
bart_final_rec <- finalize_recipe(model_rec, bart_best_parms)
bart_final_model <- finalize_model(bart_spec, bart_best_parms)

# Create new workflow for final model
bart_final_wf <- workflow() %>% 
  add_recipe(bart_final_rec) %>%
  add_model(bart_final_model)


# Run final model with best hyperparameters and evaluate on the test set
set.seed(678)
bart_final <- bart_final_wf %>%
  last_fit(data_split,
           metrics = metrics)



collect_metrics(bart_final) # Metrics of model (test set)


# These two functions should give the same predictions, but they don't.
# For models built on other algorithms, both always produce the same results.

collect_predictions(bart_final)

# This call produces different predictions each time unless set.seed() is
# run first; the predictions also differ from collect_predictions()
bart_final %>% extract_workflow() %>% predict(data_test)
@topepo (Member) commented Jan 31, 2025

For prediction, BART takes a sample from the predictive posterior distribution. If you set the seed right before running predict(), you'll get the same answer. For example:

> set.seed(1)
> bart_final %>% 
+  extract_workflow() %>% 
+  predict(data_test) %>% 
+  pluck(".pred") %>% 
+  head()
[1] 5.161443 4.613805 4.646856 4.951699 5.045785 5.118415
> 
> set.seed(2)
> bart_final %>% 
+  extract_workflow() %>% 
+  predict(data_test) %>% 
+  pluck(".pred") %>% 
+  head()
[1] 5.179751 4.620760 4.658856 4.953538 5.058127 5.118422
> 
> set.seed(2)
> bart_final %>% 
+  extract_workflow() %>% 
+  predict(data_test) %>% 
+  pluck(".pred") %>% 
+  head()
[1] 5.179751 4.620760 4.658856 4.953538 5.058127 5.118422

The problem you're seeing is that last_fit() consumes some random numbers while fitting the model, so when predict() is called afterward, the random number stream is in a different state.
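
If you want run-to-run stability without scattering set.seed() calls around, a small wrapper that seeds the RNG immediately before each prediction works as a minimal sketch (predict_bart_seeded() is a hypothetical helper, not a tidymodels function):

# Hypothetical helper: fix the RNG state right before sampling from the
# predictive posterior so repeated calls return identical draws
predict_bart_seeded <- function(wf, new_data, seed = 1) {
  set.seed(seed)
  predict(wf, new_data)
}

fitted_wf <- extract_workflow(bart_final)
p1 <- predict_bart_seeded(fitted_wf, data_test)
p2 <- predict_bart_seeded(fitted_wf, data_test)
identical(p1, p2)  # TRUE: same seed, same posterior sample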
