vignettes/computational_algorithm1.Rmd
Computational algorithms such as variational inference (VI) can fail because the approximation family cannot capture the true posterior, because the convergence metric under- or over-penalizes, or because the optimization converges too slowly. We'll discuss three examples:
- In Example I, a simple Poisson model is well handled by default ADVI when the dataset is small, but becomes miscalibrated as the number of observations grows. It also turns out that for such a simple model, `optimizing` leads to very good results.
- In Example II, we discuss a hidden Markov model where the ADVI approximation is imperfect but not very wrong. We also show how the (mis)calibration responds to changing the parameters of the ADVI implementation, and that `optimizing` performs worse than ADVI.
- In Example III, we show that a small modification to the model from Example II makes the ADVI approximation perform much worse.
VI can be applied once we have checked the fit between the posterior and the approximation family, the convergence metric, and the optimization process, so that efficiency is gained without sacrificing too much accuracy. Beyond its role as "the test" a computational algorithm should pass, SBC provides informative inferential results that directly affect workflow decisions.
HMC can be slow, and depending on the joint posterior (as a combination of data, prior, and likelihood) and the user's goal, deterministic approximation algorithms can be a useful aid. Specifically, if the joint posterior is well behaved enough for a reliable approximation (e.g. roughly symmetric for ADVI, whose approximation family is normal), or if the user only needs a point estimate (i.e. a characterization up to the distribution level is not needed), deterministic alternatives such as ADVI as supported by Stan are worth considering. Note that Pathfinder (Zhang, 2021), which aims to combine the efficiency of deterministic algorithms with the accuracy of stochastic ones in a timely manner, is under development. SBC provides one standard to test whether ADVI works well for your model, without ever needing to run full HMC on it.
Let’s start by setting up our environment.
library(SBC)
library(ggplot2)
library(cmdstanr)
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
# Parallel processing
library(future)
plan(multisession)
# The fits are very fast,
# so we force a minimum chunk size to reduce the overhead of
# parallelization and decrease computation time.
options(SBC.min_chunk_size = 5)
# Setup caching of results
cache_dir <- "./_approximate_computation_SBC_cache"
if(!dir.exists(cache_dir)) {
dir.create(cache_dir)
}
We'll start with the extremely simple Poisson model already introduced in the basic usage vignette:
data {
  int N;
  array[N] int y;
}

parameters {
  real<lower = 0> lambda;
}

model {
  lambda ~ gamma(15, 5);
  y ~ poisson(lambda);
}
And here’s R code that generates data matching that model:
poisson_generator_single <- function(N){
  # N is the number of data points we are generating
  lambda <- rgamma(n = 1, shape = 15, rate = 5)
  y <- rpois(n = N, lambda = lambda)
  list(
    variables = list(
      lambda = lambda
    ),
    generated = list(
      N = N,
      y = y
    )
  )
}
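As a quick sanity check (an illustration, not part of the SBC workflow itself), we can call the generator once and inspect the structure it produces:

# One draw from the generator: a single "true" lambda from the prior
# plus the corresponding Stan data (N and the observed counts y)
set.seed(1)
str(poisson_generator_single(N = 5))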
We’ll start with Stan’s ADVI with all default parameters, i.e. a mean-field variational approximation. We compile the model and create a variational SBC backend.
model_poisson <- cmdstan_model("stan/poisson.stan")
backend_poisson <- SBC_backend_cmdstan_variational(model_poisson, n_retries_init = 3)
Note that we allow the backend to retry initialization several times (`n_retries_init`), as the ADVI implementation in Stan can sometimes fail to start properly on the first try, even for very simple models. The ability to retry is an extension in the SBC package and is not implemented in Stan itself.
Throughout the vignette, we’ll also use caching for the results.
Since the model runs quickly and is simple, we start with 1000 simulations.
set.seed(46522641)
ds_poisson <- generate_datasets(
SBC_generator_function(poisson_generator_single, N = 20),
n_sims = 1000)
res_poisson <-
compute_SBC(
ds_poisson, backend_poisson, keep_fits = FALSE,
cache_mode = "results", cache_location = file.path(cache_dir, "poisson"))
## Results loaded from cache file 'poisson'
Even with the quite high precision afforded by 1000 simulations, the ECDF diff plot and the ranks show no problems - the model is quite well calibrated, although the wavy shape of the ECDF suggests a minor overconfidence of the approximation:
plot_ecdf_diff(res_poisson)
plot_rank_hist(res_poisson)
To put this in different terms, we can look at the observed coverage of the central 50%, 80% and 95% intervals. We see that the observed coverage for the 50% and 80% intervals is a bit lower than expected.
empirical_coverage(res_poisson$stats, width = c(0.95, 0.8, 0.5))
## # A tibble: 3 × 6
## variable width width_represented ci_low estimate ci_high
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lambda 0.5 0.5 0.418 0.449 0.480
## 2 lambda 0.8 0.8 0.749 0.776 0.801
## 3 lambda 0.95 0.95 0.934 0.949 0.961
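These coverage numbers can also be understood directly from the rank statistics: the true value lands in the central 50% posterior interval roughly whenever its normalized rank lies between 0.25 and 0.75. Below is a minimal sketch of that logic using the `rank` and `max_rank` columns of the SBC statistics (`empirical_coverage` additionally handles discreteness and provides uncertainty intervals):

# Normalized ranks approximate the posterior CDF at the true value
u <- res_poisson$stats$rank / res_poisson$stats$max_rank
mean(u > 0.25 & u < 0.75)  # crude estimate of central 50% interval coverage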
One would expect the normal approximation implemented in ADVI to become better with increasing data size; this is however not necessarily true - let's run the same model, but increase `N`, the number of observed data points:
set.seed(23546224)
ds_poisson_100 <- generate_datasets(
SBC_generator_function(poisson_generator_single, N = 100),
n_sims = 1000)
res_poisson_100 <-
compute_SBC(ds_poisson_100, backend_poisson, keep_fits = FALSE,
cache_mode = "results", cache_location = file.path(cache_dir, "poisson_100"))
## Results loaded from cache file 'poisson_100'
In this case the model becomes clearly overconfident:
plot_ecdf_diff(res_poisson_100)
plot_rank_hist(res_poisson_100)
The empirical coverage of the central intervals confirms this:
empirical_coverage(res_poisson_100$stats, width = c(0.95, 0.8, 0.5))
## # A tibble: 3 × 6
## variable width width_represented ci_low estimate ci_high
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lambda 0.5 0.5 0.411 0.442 0.473
## 2 lambda 0.8 0.8 0.684 0.713 0.740
## 3 lambda 0.95 0.95 0.883 0.903 0.920
If the model is so simple, maybe a simple Laplace approximation around the posterior mode would suffice? We can use Stan's `optimizing` mode exactly for that. Unfortunately, this is currently implemented only in `rstan` and not in `cmdstanr` (because the underlying CmdStan does not expose the Hessian of the optimizing fit). So let us build an optimizing backend:
model_poisson_rstan <- stan_model("stan/poisson.stan")
backend_poisson_optimizing <- SBC_backend_rstan_optimizing(model_poisson_rstan)
and use it to fit the same datasets - first the one with `N = 20`.
res_poisson_optimizing <-
compute_SBC(ds_poisson, backend_poisson_optimizing, keep_fits = FALSE,
cache_mode = "results", cache_location = file.path(cache_dir, "poisson_opt"))
## Results loaded from cache file 'poisson_opt'
The resulting ECDF and rank plots are very good.
plot_ecdf_diff(res_poisson_optimizing)
plot_rank_hist(res_poisson_optimizing)
Similarly, we can fit the `N = 100` datasets.
res_poisson_optimizing_100 <-
compute_SBC(ds_poisson_100, backend_poisson_optimizing, keep_fits = FALSE,
cache_mode = "results", cache_location = file.path(cache_dir, "poisson_opt_100"))
## Results loaded from cache file 'poisson_opt_100'
The resulting rank plot once again indicates no serious issues and we thus get better results here than with ADVI.
plot_ecdf_diff(res_poisson_optimizing_100)
plot_rank_hist(res_poisson_optimizing_100)
We see that for simple models, ADVI can provide a very tight approximation to exact inference, but this cannot be taken for granted. Surprisingly, having more data does not necessarily make the ADVI approximation better. Additionally, for such simple models, a simple Laplace approximation around the posterior mode works better (and likely faster) than ADVI.
We'll now jump to a considerably more complex model (partly because we wanted to have an HMM example).
In this example, we have collected a set of counts of particles emitted by a specimen in a relatively large number of experimental runs. We however noticed that there is a suspiciously large number of low counts. Inspecting the equipment, it turns out that the experiment was not set up properly and in some of the runs, our detector could only register background noise. We however don’t know which runs were erroneous.
So we assume that some experiments contain both background noise and the signal of interest and the rest contain just the background. For simplicity, we assume a Poisson distribution for the counts.
Additionally, whether we observe background only or signal is not independent across data points, and we want to model how the experimental setup switches between these two states over time. We add additional structure to the model to account for this autocorrelation.
One possible choice for such structure is a hidden Markov model (HMM), where we assume the probability of transitioning from one state to another is identical across all time points. The Stan case study on HMMs has a more thorough discussion and also shows how to code these in Stan.
Maybe the simplest way to describe the model is to show how we simulate the data:
generator_HMM <- function(N) {
  mu_background <- rlnorm(1, -2, 1)
  mu_signal <- rlnorm(1, 2, 1)

  # Draw the transition probabilities
  t1 <- MCMCpack::rdirichlet(1, c(3, 3))
  t2 <- MCMCpack::rdirichlet(1, c(3, 3))

  states <- rep(NA_integer_, N)
  # Draw from the initial state distribution
  rho <- MCMCpack::rdirichlet(1, c(1, 10))

  # Simulate the hidden states
  states[1] <- sample(1:2, size = 1, prob = rho)
  for(n in 2:length(states)) {
    if(states[n - 1] == 1)
      states[n] <- sample(c(1, 2), size = 1, prob = t1)
    else if(states[n - 1] == 2)
      states[n] <- sample(c(1, 2), size = 1, prob = t2)
  }

  # Simulate observations given the states
  mu <- c(mu_background, mu_background + mu_signal)
  y <- rpois(N, mu[states])

  list(
    variables = list(
      mu_background = mu_background,
      mu_signal = mu_signal,
      # rdirichlet returns matrices, convert to 1D vectors
      t1 = as.numeric(t1),
      t2 = as.numeric(t2),
      rho = as.numeric(rho)
    ),
    generated = list(
      N = N,
      y = y
    )
  )
}
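To see the pattern that motivated the model - an excess of low counts coming from the background-only state - we can simulate a single dataset and tabulate the counts (an illustration with an arbitrary seed, not part of the SBC workflow):

# Counts from one simulated run: low counts dominate whenever the
# chain spends time in the background-only state
set.seed(5665)
sim_hmm <- generator_HMM(N = 100)
table(cut(sim_hmm$generated$y, breaks = c(-0.5, 0.5, 1.5, 5.5, Inf),
          labels = c("0", "1", "2-5", ">5")))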
And here is the Stan code that models this process (it is based on the example from the HMM case study but simplified and modified).
data {
  int N;  // Number of observations
  array[N] int y;
}

parameters {
  // Parameters of measurement model
  real<lower=0> mu_background;
  real<lower=0> mu_signal;

  // Initial state
  simplex[2] rho;

  // Rows of the transition matrix
  simplex[2] t1;
  simplex[2] t2;
}

model {
  matrix[2, 2] Gamma;
  matrix[2, N] log_omega;

  // Build the transition matrix
  Gamma[1, : ] = t1';
  Gamma[2, : ] = t2';

  // Compute the log likelihoods in each possible state
  for (n in 1 : N) {
    // The observation model could change with n, or vary in a number of
    // different ways (which is why log_omega is passed in as an argument)
    log_omega[1, n] = poisson_lpmf(y[n] | mu_background);
    log_omega[2, n] = poisson_lpmf(y[n] | mu_background + mu_signal);
  }

  mu_background ~ lognormal(-2, 1);
  mu_signal ~ lognormal(2, 1);
  // Initial state - we're quite sure we started with the source working
  rho ~ dirichlet([1, 10]);
  t1 ~ dirichlet([3, 3]);
  t2 ~ dirichlet([3, 3]);

  target += hmm_marginal(log_omega, Gamma, rho);
}
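For intuition, `hmm_marginal(log_omega, Gamma, rho)` marginalizes over the hidden state sequence via the forward algorithm (Stan works on the log scale for numerical stability). Here is a direct R sketch of the same recursion - purely illustrative, and numerically unstable for long series:

# Forward algorithm on the probability scale:
# omega[k, n] = P(y[n] | hidden state k), Gamma = transition matrix
# (rows sum to 1), rho = initial state distribution
forward_marginal <- function(omega, Gamma, rho) {
  alpha <- rho * omega[, 1]                  # P(y[1], z[1] = k)
  for (n in 2:ncol(omega)) {
    alpha <- as.vector(t(Gamma) %*% alpha) * omega[, n]
  }
  sum(alpha)                                 # P(y[1:N])
}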
We start with the default (meanfield) variational backend via Stan:
if(package_version(cmdstanr::cmdstan_version()) < package_version("2.26.0")) {
  stop("The models in this section require CmdStan 2.26 or later.")
}
model_HMM <- cmdstan_model("stan/hmm_poisson.stan")
backend_HMM <- SBC_backend_cmdstan_variational(model_HMM, n_retries_init = 3)
Since we are feeling confident that our model is implemented correctly (and the model runs quickly), we start with 100 simulations and assume 100 observations for each. If you are developing a new model, it might be useful to start with fewer simulations, as discussed in the small model workflow vignette.
And we compute the results:
set.seed(642354822)
ds_hmm <- generate_datasets(SBC_generator_function(generator_HMM, N = 100), n_sims = 100)
res_hmm <- compute_SBC(ds_hmm, backend_HMM,
cache_mode = "results", cache_location = file.path(cache_dir, "hmm"))
## Results loaded from cache file 'hmm'
There are no huge problems, but the `mu_signal` variable seems to not be well calibrated:
plot_ecdf_diff(res_hmm)
plot_rank_hist(res_hmm)
We may also look at the observed coverage of central intervals - we see that for `mu_signal` the approximation tends to be overconfident for the wider intervals.
plot_coverage(res_hmm)
To make sure this is not a fluke, we add 400 more simulations.
set.seed(2254355)
ds_hmm_2 <- generate_datasets(SBC_generator_function(generator_HMM, N = 100), n_sims = 400)
res_hmm_2 <- bind_results(
res_hmm,
compute_SBC(ds_hmm_2,backend_HMM,
cache_mode = "results",
cache_location = file.path(cache_dir, "hmm2"))
)
## Results loaded from cache file 'hmm2'
This confirms the problems with `mu_signal`. Additionally, we see that `mu_background` and the `rho` variables also show some irregularities.
plot_ecdf_diff(res_hmm_2)
plot_rank_hist(res_hmm_2)
Looking at the observed coverage, both `mu_background` and `mu_signal` are now clearly somewhat overconfident for the wider intervals.
plot_coverage(res_hmm_2)
This is what we get when we focus on the 90% posterior credible interval:
coverage_hmm <- empirical_coverage(res_hmm_2$stats, width = 0.9)[, c("variable", "ci_low", "ci_high")]
coverage_hmm
## # A tibble: 8 × 3
## variable ci_low ci_high
## <chr> <dbl> <dbl>
## 1 mu_background 0.822 0.884
## 2 mu_signal 0.754 0.825
## 3 rho[1] 0.864 0.918
## 4 rho[2] 0.864 0.918
## 5 t1[1] 0.803 0.868
## 6 t1[2] 0.803 0.868
## 7 t2[1] 0.838 0.897
## 8 t2[2] 0.838 0.897
So the 90% central credible interval for `mu_signal` likely contains less than 83% of the true values.
For a crude result, the default ADVI setup we just tested is not terrible: we don't expect a strong bias, and the model will be somewhat overconfident, though not catastrophically so.
Note that when the user is aiming for a point estimate of the mean or another central tendency, a summary of the VI posterior may provide a good point estimate even when the uncertainty is miscalibrated. VSBC, a diagnostic that concentrates on bias in marginal quantities, was developed to test this (Yao et al., 2018), but is currently not implemented in our package (see https://github.com/hyunjimoon/SBC/issues/60 for progress). Other diagnostics, such as PSIS-based ones, are tied to a specific dataset and test quantity, and are thus less flexible for targeted testing.
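While VSBC itself is not available in the package, a very crude check of marginal bias in a similar spirit is to average the standardized difference between the posterior mean and the true value - the `z_score` column SBC computes in `$stats`. This is only a sketch, not equivalent to VSBC:

# Average z-score of the true value per variable; values far from 0
# suggest systematic bias in that marginal. (Sketch only, not VSBC.)
aggregate(z_score ~ variable, data = res_hmm_2$stats, FUN = mean)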
We may check whether the situation improves with full-rank ADVI - let's run it on the same datasets.
ds_hmm_all <- bind_datasets(ds_hmm, ds_hmm_2)
res_hmm_fullrank <- compute_SBC(
ds_hmm_all,
SBC_backend_cmdstan_variational(model_HMM, algorithm = "fullrank", n_retries_init = 3),
cache_mode = "results", cache_location = file.path(cache_dir, "hmm_fullrank"))
## Results loaded from cache file 'hmm_fullrank'
We still have problems, but different ones (and arguably somewhat less severe):
plot_ecdf_diff(res_hmm_fullrank)
plot_rank_hist(res_hmm_fullrank)
Interestingly, the rank plot for `mu_signal` shows a "frowning" shape, meaning the full-rank approximation is slightly underconfident here.
This is nicely demonstrated by looking at the central interval coverage - now the coverage of `mu_signal` is larger than it should be, so the model is underconfident (i.e. more conservative), while the coverages for the other variables track the nominal values quite closely.
plot_coverage(res_hmm_fullrank)
Or alternatively, looking at the numerical values for coverage of the central 90% interval:
coverage_hmm_fullrank <-
empirical_coverage(res_hmm_fullrank$stats, width = 0.9)[, c("variable", "ci_low", "ci_high")]
coverage_hmm_fullrank
## # A tibble: 8 × 3
## variable ci_low ci_high
## <chr> <dbl> <dbl>
## 1 mu_background 0.851 0.907
## 2 mu_signal 0.891 0.939
## 3 rho[1] 0.875 0.927
## 4 rho[2] 0.875 0.927
## 5 t1[1] 0.868 0.922
## 6 t1[2] 0.868 0.922
## 7 t2[1] 0.879 0.930
## 8 t2[2] 0.879 0.930
This pattern, where the default mean-field approximation is overconfident and the full-rank approximation underconfident, is in fact seen quite frequently. It motivated some experiments with a low-rank approximation that would fall in between the two, but as of early 2022 this is not ready for use in Stan.
In some cases, it might also help to reduce the tolerance (`tol_rel_obj`) of the algorithm - the convergence criterion on relative changes in the evidence lower bound (ELBO); lowering it enforces tighter optimization convergence. Here we'll use the default mean-field algorithm, but decrease `tol_rel_obj` (the default value is 0.01). So let's try that.
res_hmm_lowtol <- compute_SBC(
ds_hmm_all,
SBC_backend_cmdstan_variational(model_HMM, tol_rel_obj = 0.001, n_retries_init = 3),
cache_mode = "results", cache_location = file.path(cache_dir, "hmm_lowtol"))
## Results loaded from cache file 'hmm_lowtol'
## - 14 (3%) of fits did not converge.
## Not all diagnostics are OK.
## You can learn more by inspecting $default_diagnostics, $backend_diagnostics
## and/or investigating $outputs/$messages/$warnings for detailed output from the backend.
Reducing the tolerance leads to a small proportion of non-converging fits. In theory, increasing `grad_samples` should improve convergence, but in our experience the current ADVI implementation (as of 2021) does not respond easily to this adjustment. Also, since the non-converged cases are relatively rare, we'll just remove the non-converging fits from the SBC results (this is OK as long as we would also discard non-converging fits for real data, see the rejection sampling vignette).
res_hmm_lowtol_conv <-
res_hmm_lowtol[res_hmm_lowtol$backend_diagnostics$elbo_converged]
plot_ecdf_diff(res_hmm_lowtol_conv)
plot_rank_hist(res_hmm_lowtol_conv)
The problems seem to have become even less pronounced. We may once again inspect the observed coverage of central intervals
plot_coverage(res_hmm_lowtol_conv)
and the numerical values for the coverage of the central 90% interval.
empirical_coverage(res_hmm_lowtol$stats, width = 0.9)[, c("variable", "ci_low", "ci_high")]
## # A tibble: 8 × 3
## variable ci_low ci_high
## <chr> <dbl> <dbl>
## 1 mu_background 0.827 0.888
## 2 mu_signal 0.831 0.891
## 3 rho[1] 0.882 0.932
## 4 rho[2] 0.882 0.932
## 5 t1[1] 0.814 0.877
## 6 t1[2] 0.814 0.877
## 7 t2[1] 0.844 0.902
## 8 t2[2] 0.844 0.902
This variant has a somewhat lower overall mismatch, but tends to be overconfident, which might in some cases be less desirable than the more conservative full-rank variant.
Would optimizing provide sensible results in this case? We build an optimizing backend and run it.
SBC:::require_package_version("rstan", "2.26", "The models in the following sections need more recent rstan than what is available on CRAN - use https://mc-stan.org/r-packages/ to get it")
model_HMM_rstan <- stan_model("stan/hmm_poisson.stan")
res_hmm_optimizing <- compute_SBC(
ds_hmm_all,
SBC_backend_rstan_optimizing(model_HMM_rstan, n_retries_hessian = 3),
cache_mode = "results", cache_location = file.path(cache_dir, "hmm_optimizing"))
## Results loaded from cache file 'hmm_optimizing'
## - 1 (0%) of fits required multiple attempts to produce usable Hessian.
## Not all diagnostics are OK.
## You can learn more by inspecting $default_diagnostics, $backend_diagnostics
## and/or investigating $outputs/$messages/$warnings for detailed output from the backend.
We see that while for some variables (`mu_signal`, the transition probabilities `t1` and `t2`) the Laplace approximation is reasonably well calibrated, it is very badly calibrated with respect to the initial state `rho` and also for `mu_background`, where there is substantial bias. So if we were only interested in a subset of the variables, the optimizing fit could still be an OK choice.
plot_ecdf_diff(res_hmm_optimizing)
plot_rank_hist(res_hmm_optimizing)
To summarise, the HMM model turns out to pose minor problems for ADVI that can be partially resolved by tweaking the parameters of the ADVI algorithm. Just using optimizing results in much worse calibration than ADVI.
Another relevant question is how much speed we gained. To have a comparison, we run full MCMC with Stan for the same datasets.
res_hmm_sample <- compute_SBC(
ds_hmm[1:50],
SBC_backend_cmdstan_sample(model_HMM),
keep_fits = FALSE,
cache_mode = "results", cache_location = file.path(cache_dir, "hmm_sample"))
## Results loaded from cache file 'hmm_sample'
## - 1 (2%) fits had at least one Rhat > 1.01. Largest Rhat was 1.019.
## - 1 (2%) fits had tail ESS undefined or less than half of the maximum rank, potentially skewing
## the rank statistics. The lowest tail ESS was 154.
## If the fits look good otherwise, increasing `thin_ranks` (via recompute_SBC_statistics)
## or number of posterior draws (by refitting) might help.
## - 1 (2%) fits had divergent transitions. Maximum number of divergences was 51.
## Not all diagnostics are OK.
## You can learn more by inspecting $default_diagnostics, $backend_diagnostics
## and/or investigating $outputs/$messages/$warnings for detailed output from the backend.
We get a small number of problematic fits, which we will ignore for now. We check that there are no obvious calibration problems:
plot_ecdf_diff(res_hmm_sample)
plot_rank_hist(res_hmm_sample)
For the machine we built the vignette on, here are the distributions of fit times (for ADVI and optimizing) and of the time of the longest chain (for HMC):
hmm_time <-
rbind(
data.frame(alg = "Optimizing",
time = res_hmm_optimizing$backend_diagnostics$time),
data.frame(alg = "Meanfield",
time = res_hmm$backend_diagnostics$time),
data.frame(alg = "Fullrank",
time = res_hmm_fullrank$backend_diagnostics$time),
data.frame(alg = "Meanfield + low tol.",
time = res_hmm_lowtol$backend_diagnostics$time),
data.frame(alg = "Sampling (longest chain)",
time = res_hmm_sample$backend_diagnostics$max_chain_time))
max_time_optimizing <- round(max(res_hmm_optimizing$backend_diagnostics$time), 2)
hmm_time$alg <- factor(hmm_time$alg,
levels = c("Optimizing",
"Meanfield",
"Fullrank",
"Meanfield + low tol.",
"Sampling (longest chain)"))
ggplot(hmm_time, aes(x = time)) +
geom_histogram(aes(y = after_stat(density)), bins = 20) +
facet_wrap(~alg, ncol = 1) +
scale_x_continuous("Time [seconds]")
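To complement the histograms, here is a quick numerical summary of the same timing data (a sketch built on the `hmm_time` data frame from the chunk above):

# Median and maximum wall time per algorithm
aggregate(time ~ alg, data = hmm_time,
          FUN = function(t) c(median = median(t), max = max(t)))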
Depressingly, while using a lower tolerance got us almost as good uncertainty quantification as sampling, it also erased a big part of the performance advantage variational inference had over sampling for this model. However, both the fullrank and meanfield approximations provide not-terrible estimates and are noticeably faster than sampling. Optimizing is by far the fastest: the longest time observed was just 1.58 seconds.
Unfortunately, ADVI as implemented in Stan can be quite fragile. Let us consider a very small change to the HMM model from the previous section: we model the means of the counts for the two states directly (the previous version modelled the background state and the difference between the two states) and move to the log scale. So instead of `mu_background` and `mu_signal` we have an `ordered` vector `log_mu`:
data {
  int N;  // Number of observations
  array[N] int y;
}

parameters {
  // Parameters of measurement model
  ordered[2] log_mu;

  // Initial state
  simplex[2] rho;

  // Rows of the transition matrix
  simplex[2] t1;
  simplex[2] t2;
}

model {
  matrix[2, 2] Gamma;
  matrix[2, N] log_omega;

  // Build the transition matrix
  Gamma[1, : ] = t1';
  Gamma[2, : ] = t2';

  // Compute the log likelihoods in each possible state
  for (n in 1 : N) {
    // The observation model could change with n, or vary in a number of
    // different ways (which is why log_omega is passed in as an argument)
    log_omega[1, n] = poisson_log_lpmf(y[n] | log_mu[1]);
    log_omega[2, n] = poisson_log_lpmf(y[n] | log_mu[2]);
  }

  log_mu[1] ~ normal(-2, 1);
  log_mu[2] ~ normal(2, 1);
  // Initial state - we're quite sure we started with the source working
  rho ~ dirichlet([1, 10]);
  t1 ~ dirichlet([3, 3]);
  t2 ~ dirichlet([3, 3]);

  target += hmm_marginal(log_omega, Gamma, rho);
}

generated quantities {
  positive_ordered[2] mu = exp(log_mu);
}
This model is almost identical - in theory the only difference is that it implies a slightly different prior on the active (higher mean) state. Here is how we can generate data with this mildly different prior (we need rejection sampling to fulfill the ordering constraint):
generator_HMM_ordered <- function(N) {
  # Rejection sampling for ordered mu with the correct priors
  repeat {
    log_mu <- c(rnorm(1, -2, 1), rnorm(1, 2, 1))
    if(log_mu[1] < log_mu[2]) {
      break
    }
  }
  mu <- exp(log_mu)

  # Draw the transition probabilities
  t1 <- MCMCpack::rdirichlet(1, c(3, 3))
  t2 <- MCMCpack::rdirichlet(1, c(3, 3))

  states <- rep(NA_integer_, N)
  # Draw from the initial state distribution
  rho <- MCMCpack::rdirichlet(1, c(1, 10))
  states[1] <- sample(1:2, size = 1, prob = rho)
  for(n in 2:length(states)) {
    if(states[n - 1] == 1)
      states[n] <- sample(c(1, 2), size = 1, prob = t1)
    else if(states[n - 1] == 2)
      states[n] <- sample(c(1, 2), size = 1, prob = t2)
  }
  y <- rpois(N, mu[states])

  list(
    variables = list(
      log_mu = log_mu,
      # rdirichlet returns matrices, convert to 1D vectors
      t1 = as.numeric(t1),
      t2 = as.numeric(t2),
      rho = as.numeric(rho)
    ),
    generated = list(
      N = N,
      y = y
    )
  )
}
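Note that the rejection step above is cheap: under these priors, `log_mu[2] - log_mu[1]` is normal with mean 4 and standard deviation `sqrt(2)`, so the ordering constraint holds with probability `pnorm(4 / sqrt(2))`, roughly 0.998. A quick empirical check (illustration only):

# Empirical acceptance rate of the ordering constraint; should be ~0.998
mean(replicate(1e5, rnorm(1, -2, 1) < rnorm(1, 2, 1)))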
So let us build a default variational backend and fit it to just 20 simulations.
model_HMM_ordered <- cmdstan_model("stan/hmm_poisson_ordered.stan")
backend_HMM_ordered <- SBC_backend_cmdstan_variational(model_HMM_ordered, n_retries_init = 3)
set.seed(12333654)
ds_hmm_ordered <- generate_datasets(
SBC_generator_function(generator_HMM_ordered, N = 100),
n_sims = 20)
res_hmm_ordered <-
compute_SBC(ds_hmm_ordered, backend_HMM_ordered,
cache_mode = "results", cache_location = file.path(cache_dir, "hmm_ordered"))
## Results loaded from cache file 'hmm_ordered'
Immediately we see that the `log_mu[1]` variable is heavily miscalibrated.
plot_ecdf_diff(res_hmm_ordered)
plot_rank_hist(res_hmm_ordered)
What changed? To understand this, we need to remember how Stan represents constrained data types. In short, in the model in Example II, Stan internally works with the unconstrained parameters `mu_background__ = log(mu_background)` and `mu_signal__ = log(mu_signal)`. In this modified model, the internal representation is `log_mu_1__ = log_mu[1]` (without any change) and `log_mu_2__ = log(log_mu[2] - log_mu[1])`. So the mean of the active component is actually `exp(log_mu_1__ + exp(log_mu_2__))`. This introduces a complex correlation structure between the unconstrained parameters that the ADVI algorithm is unable to handle well.
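A tiny R illustration of this transform (a sketch mirroring Stan's unconstraining of an `ordered[2]` vector, not code from the package):

# Stan's unconstrained representation of an ordered[2] vector
log_mu <- c(-2, 2)
log_mu_1__ <- log_mu[1]                   # first element: unchanged
log_mu_2__ <- log(log_mu[2] - log_mu[1])  # second: log of the positive gap
# The active-state mean depends non-linearly on BOTH unconstrained parameters:
mu_active <- exp(log_mu_1__ + exp(log_mu_2__))
stopifnot(all.equal(mu_active, exp(log_mu[2])))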
Even trying the fullrank variant does not help:
backend_HMM_ordered_fullrank <-
SBC_backend_cmdstan_variational(model_HMM_ordered,
algorithm = "fullrank", n_retries_init = 3)
res_hmm_ordered_fullrank <-
  compute_SBC(ds_hmm_ordered, backend_HMM_ordered_fullrank,
              cache_mode = "results", cache_location = file.path(cache_dir, "hmm_ordered_fullrank"))
## Results loaded from cache file 'hmm_ordered_fullrank'
The results are still strongly miscalibrated.
plot_ecdf_diff(res_hmm_ordered_fullrank)
plot_rank_hist(res_hmm_ordered_fullrank)
To have a complete overview we may also try the optimizing fit:
model_HMM_ordered_rstan <- stan_model("stan/hmm_poisson_ordered.stan")
res_hmm_ordered_optimizing <- compute_SBC(
ds_hmm_ordered,
SBC_backend_rstan_optimizing(model_HMM_ordered_rstan),
cache_mode = "results", cache_location = file.path(cache_dir, "hmm_ordered_optimizing"))
## Results loaded from cache file 'hmm_ordered_optimizing'
In this case, optimizing has better calibration than ADVI for `log_mu`, but worse for `rho`.
plot_ecdf_diff(res_hmm_ordered_optimizing)
plot_rank_hist(res_hmm_ordered_optimizing)
As this vignette has shown, for some models ADVI provides results close to what we would get with sampling, but it may also fail catastrophically on models that are only slightly different, and tweaking the algorithm's parameters may be necessary. In some cases where ADVI works, the Laplace approximation via optimizing will also work well. ADVI (and optimizing) thus cannot be used blindly. Fortunately, SBC can be used to check for this type of problem without ever needing to run full sampling.
In computational_algorithm2, we will focus on the hopeful aspects of approximate computation. We introduce the adversarial relation between computation and diagnostics, through which the two can evolve together. This can give insight to designers of computational algorithms aiming to pass SBC. As an illustration, we discuss when and how VI can be used, including customized SBC (e.g. VSBC) and first- or second-order corrections.