Skip to content

Estimate Parameters for Arbitrary R Functions using 'Stan'

License

Unknown, MIT licenses found

Licenses found

Unknown
LICENSE
MIT
LICENSE.md
Notifications You must be signed in to change notification settings

andrjohns/StanEstimators

Folders and files

NameName
Last commit message
Last commit date

Latest commit

db7b965 · Jul 10, 2024
Jan 2, 2024
Jun 18, 2024
Jun 18, 2024
Jun 10, 2024
Jul 10, 2024
Jun 10, 2024
Jan 7, 2024
Dec 3, 2023
Jan 18, 2024
Jun 10, 2024
Dec 1, 2023
Dec 1, 2023
Jan 7, 2024
Jan 2, 2024
Jan 2, 2024
Dec 1, 2023

Repository files navigation

StanEstimators

CRAN status R-CMD-check StanEstimators status badge

The StanEstimators package provides an estimation back-end for R functions, similar to those provided by the optim package, using the algorithms provided by the Stan probabilistic programming language.

As Stan’s algorithms are gradient-based, function gradients can be automatically calculated using finite-differencing or the user can provide a function for analytical calculation.

Installation

You can install pre-built binaries using:

# we recommend running this is a fresh R session or restarting your current session
install.packages('StanEstimators', repos = c('https://andrjohns.r-universe.dev', 'https://cloud.r-project.org'))

Or you can build from source using:

# install.packages("remotes")
remotes::install_github("andrjohns/StanEstimators")

Usage

Consider the goal of estimating the mean and standard deviation of a normal distribution, with uniform uninformative priors on both parameters:

y N ( μ , σ )

μ U [ , ]

σ U [ 0 , ]

With known true values for verification:

y <- rnorm(500, 10, 2)

As with other estimation routines provided in R, we need to specify this as a function which takes a vector of parameters as its first argument and returns a single scalar value (the log-likelihood), as well as initial values for the parameters:

loglik_fun <- function(v, x) {
  sum(dnorm(x, v[1], v[2], log = TRUE))
}

inits <- c(0, 5)

Estimation time can also be significantly reduced by providing a gradient function, rather than relying on finite-differencing:

grad <- function(v, x) {
  inv_sigma <- 1 / v[2]
  y_scaled = (x - v[1]) * inv_sigma
  scaled_diff = inv_sigma * y_scaled
  c(sum(scaled_diff),
    sum(inv_sigma * (y_scaled*y_scaled) - inv_sigma)
  )
}

MCMC Estimation

Full MCMC estimation is provided by the stan_sample() function, which uses Stan’s default No U-Turn Sampler (NUTS) unless otherwise specified:

library(StanEstimators)

fit <- stan_sample(loglik_fun, inits, additional_args = list(y),
                   lower = c(-Inf, 0), # Enforce a positivity constraint for SD
                   num_chains = 1, seed = 1234)

We can see that the parameters were recovered accurately and that the estimation was relatively fast: ~1 sec for 1000 warmup and 1000 iterations

unlist(fit@timing)
#>   warmup sampling 
#>    0.527    0.490
summary(fit)
#> # A tibble: 3 × 10
#>   variable    mean  median     sd    mad      q5     q95  rhat ess_bulk ess_tail
#>   <chr>      <dbl>   <dbl>  <dbl>  <dbl>   <dbl>   <dbl> <dbl>    <dbl>    <dbl>
#> 1 lp__     -1.08e3 -1.08e3 1.03   0.749  -1.08e3 -1.08e3  1.00     507.     672.
#> 2 pars[1]   1.01e1  1.01e1 0.0940 0.0948  9.97e0  1.03e1  1.00     895.     671.
#> 3 pars[2]   2.11e0  2.10e0 0.0686 0.0670  2.00e0  2.22e0  1.00     860.     696.

Estimation time can be improved further by providing a gradient function:

fit_grad <- stan_sample(loglik_fun, inits, additional_args = list(y),
                        grad_fun = grad,
                        lower = c(-Inf, 0),
                        num_chains = 1,
                        seed = 1234)

Which shows that the estimation time was dramatically improved, now ~0.15 seconds for 1000 warmup and 1000 iterations.

unlist(fit_grad@timing)
#>   warmup sampling 
#>    0.111    0.087
summary(fit_grad)
#> # A tibble: 3 × 10
#>   variable    mean  median     sd    mad      q5     q95  rhat ess_bulk ess_tail
#>   <chr>      <dbl>   <dbl>  <dbl>  <dbl>   <dbl>   <dbl> <dbl>    <dbl>    <dbl>
#> 1 lp__     -1.08e3 -1.08e3 1.02   0.741  -1.08e3 -1.08e3  1.00     572.     712.
#> 2 pars[1]   1.01e1  1.01e1 0.0928 0.0943  9.97e0  1.03e1  1.00     950.     623.
#> 3 pars[2]   2.10e0  2.10e0 0.0691 0.0696  1.99e0  2.22e0  1.00     725.     613.

Optimization

opt_fd <- stan_optimize(loglik_fun, inits, additional_args = list(y),
                          lower = c(-Inf, 0),
                          seed = 1234)
opt_grad <- stan_optimize(loglik_fun, inits, additional_args = list(y),
                          grad_fun = grad,
                          lower = c(-Inf, 0),
                          seed = 1234)
summary(opt_fd)
#>       lp__ pars[1] pars[2]
#> 1 -1079.84 10.1221 2.09743
summary(opt_grad)
#>       lp__ pars[1] pars[2]
#> 1 -1079.84 10.1221 2.09743

Laplace Approximation

# Can provide the mode as a numeric vector:
lapl_num <- stan_laplace(loglik_fun, inits, additional_args = list(y),
                          mode = c(10, 2),
                          lower = c(-Inf, 0),
                          seed = 1234)

# Can provide the mode as a StanOptimize object:
lapl_opt <- stan_laplace(loglik_fun, inits, additional_args = list(y),
                          mode = opt_fd,
                          lower = c(-Inf, 0),
                          seed = 1234)

# Can estimate the mode before sampling:
lapl_est <- stan_laplace(loglik_fun, inits, additional_args = list(y),
                          lower = c(-Inf, 0),
                          seed = 1234)
summary(lapl_num)
#> # A tibble: 4 × 10
#>   variable     mean    median     sd    mad       q5        q95  rhat ess_bulk
#>   <chr>       <dbl>     <dbl>  <dbl>  <dbl>    <dbl>      <dbl> <dbl>    <dbl>
#> 1 log_p__  -1082.   -1082.    2.39   2.18   -1087.   -1080.     0.999     989.
#> 2 log_q__     -1.04    -0.692 1.04   0.716     -3.21    -0.0582 0.999    1047.
#> 3 pars[1]     10.0     10.0   0.0899 0.0867     9.85    10.1    1.00      933.
#> 4 pars[2]      2.00     2.00  0.0626 0.0635     1.90     2.11   1.00     1051.
#> # ℹ 1 more variable: ess_tail <dbl>
summary(lapl_opt)
#> # A tibble: 4 × 10
#>   variable     mean    median     sd    mad       q5        q95  rhat ess_bulk
#>   <chr>       <dbl>     <dbl>  <dbl>  <dbl>    <dbl>      <dbl> <dbl>    <dbl>
#> 1 log_p__  -1080.   -1080.    1.06   0.712  -1082.   -1079.     0.999    1044.
#> 2 log_q__     -1.04    -0.692 1.04   0.716     -3.21    -0.0582 0.999    1047.
#> 3 pars[1]     10.1     10.1   0.0940 0.0897     9.96    10.3    1.00      932.
#> 4 pars[2]      2.10     2.10  0.0688 0.0697     1.99     2.21   1.00     1051.
#> # ℹ 1 more variable: ess_tail <dbl>
summary(lapl_est)
#> # A tibble: 4 × 10
#>   variable     mean    median     sd    mad       q5        q95  rhat ess_bulk
#>   <chr>       <dbl>     <dbl>  <dbl>  <dbl>    <dbl>      <dbl> <dbl>    <dbl>
#> 1 log_p__  -1080.   -1080.    1.06   0.712  -1082.   -1079.     0.999    1044.
#> 2 log_q__     -1.04    -0.692 1.04   0.716     -3.21    -0.0582 0.999    1047.
#> 3 pars[1]     10.1     10.1   0.0940 0.0897     9.96    10.3    1.00      932.
#> 4 pars[2]      2.10     2.10  0.0688 0.0697     1.99     2.21   1.00     1051.
#> # ℹ 1 more variable: ess_tail <dbl>

Variational Inference

var_fd <- stan_variational(loglik_fun, inits, additional_args = list(y),
                              lower = c(-Inf, 0),
                              seed = 1234)
var_grad <- stan_variational(loglik_fun, inits, additional_args = list(y),
                              grad_fun = grad,
                              lower = c(-Inf, 0),
                              seed = 1234)
summary(var_fd)
#> # A tibble: 5 × 10
#>   variable     mean    median     sd    mad       q5        q95   rhat ess_bulk
#>   <chr>       <dbl>     <dbl>  <dbl>  <dbl>    <dbl>      <dbl>  <dbl>    <dbl>
#> 1 lp__         0        0     0      0          0        0      NA          NA 
#> 2 log_p__  -1081.   -1080.    1.33   0.986  -1083.   -1079.      0.999     997.
#> 3 log_g__     -1.03    -0.714 1.03   0.731     -3.29    -0.0486  1.00      959.
#> 4 pars[1]     10.2     10.2   0.0869 0.0898    10.1     10.4     1.00     1012.
#> 5 pars[2]      2.09     2.09  0.0650 0.0639     1.99     2.20    1.00      850.
#> # ℹ 1 more variable: ess_tail <dbl>
summary(var_grad)
#> # A tibble: 5 × 10
#>   variable     mean    median     sd    mad       q5        q95   rhat ess_bulk
#>   <chr>       <dbl>     <dbl>  <dbl>  <dbl>    <dbl>      <dbl>  <dbl>    <dbl>
#> 1 lp__         0        0     0      0          0        0      NA          NA 
#> 2 log_p__  -1081.   -1080.    1.33   0.986  -1083.   -1079.      0.999     997.
#> 3 log_g__     -1.03    -0.714 1.03   0.731     -3.29    -0.0486  1.00      959.
#> 4 pars[1]     10.2     10.2   0.0869 0.0898    10.1     10.4     1.00     1012.
#> 5 pars[2]      2.09     2.09  0.0650 0.0639     1.99     2.20    1.00      850.
#> # ℹ 1 more variable: ess_tail <dbl>

Pathfinder

path_fd <- stan_pathfinder(loglik_fun, inits, additional_args = list(y),
                              lower = c(-Inf, 0),
                              seed = 1234)
path_grad <- stan_pathfinder(loglik_fun, inits, additional_args = list(y),
                              grad_fun = grad,
                              lower = c(-Inf, 0),
                              seed = 1234)
summary(path_fd)
#> # A tibble: 4 × 10
#>   variable    mean  median     sd    mad      q5     q95  rhat ess_bulk ess_tail
#>   <chr>      <dbl>   <dbl>  <dbl>  <dbl>   <dbl>   <dbl> <dbl>    <dbl>    <dbl>
#> 1 lp_appr…  2.96e0  3.27e0 0.998  0.716   1.03e0  3.90e0 1.00      949.     909.
#> 2 lp__     -1.08e3 -1.08e3 1.04   0.726  -1.08e3 -1.08e3 1.00      946.     820.
#> 3 pars[1]   1.01e1  1.01e1 0.0955 0.0920  9.96e0  1.03e1 0.999    1004.     800.
#> 4 pars[2]   2.10e0  2.11e0 0.0668 0.0695  1.99e0  2.21e0 1.00      998.     907.
summary(path_grad)
#> # A tibble: 4 × 10
#>   variable    mean  median     sd    mad      q5     q95  rhat ess_bulk ess_tail
#>   <chr>      <dbl>   <dbl>  <dbl>  <dbl>   <dbl>   <dbl> <dbl>    <dbl>    <dbl>
#> 1 lp_appr…  2.96e0  3.27e0 0.998  0.716   1.03e0  3.90e0 1.00      949.     909.
#> 2 lp__     -1.08e3 -1.08e3 1.04   0.726  -1.08e3 -1.08e3 1.00      946.     820.
#> 3 pars[1]   1.01e1  1.01e1 0.0955 0.0920  9.96e0  1.03e1 0.999    1004.     800.
#> 4 pars[2]   2.10e0  2.11e0 0.0668 0.0695  1.99e0  2.21e0 1.00      998.     907.

About

Estimate Parameters for Arbitrary R Functions using 'Stan'

Resources

License

Unknown, MIT licenses found

Licenses found

Unknown
LICENSE
MIT
LICENSE.md

Stars

Watchers

Forks

Packages

No packages published

Languages