- compose_data to prepare a data frame for the model
- spread_draws
- point_interval functions: [median|mean|mode]_[qi|hdi]
- geom_pointinterval
- stat_eye
- stat_slabinterval
- .width = argument
- gather_draws and gather_variables
This vignette introduces the tidybayes package, which facilitates the use of tidy data (one observation per row) with Bayesian models in R. This vignette is geared towards working with tidy data in general-purpose modeling tools like JAGS or Stan. For a similar introduction to the use of tidybayes with high-level modeling functions such as those in brms or rstanarm, see vignette("tidy-brms") or vignette("tidy-rstanarm"). This vignette also describes how to use ggdist (the sister package to tidybayes) for visualizing model output.
The default output (and sometimes input) data formats of popular modeling functions like JAGS and Stan often don’t quite conform to the ideal of tidy data. For example, input formats might expect a list instead of a data frame, and for all variables to be encoded as numeric values (requiring translation of factors to numeric values and the creation of index variables to store the number of levels per factor or the number of observations in a data frame). Output formats will often be in matrix form (requiring conversion for use with libraries like ggplot), and will use numeric indices (requiring conversion back into factor level names if you wish to make meaningfully-labeled plots or tables). tidybayes automates all of these sorts of tasks.
There are a few core ideas that run through the tidybayes API that should (hopefully) make it easy to use:
Tidy data does not always mean all parameter names as values. In contrast to the ggmcmc library (which translates model results into a data frame with a Parameter and value column), the spread_draws function in tidybayes produces data frames where the columns are named after parameters and (in some cases) indices of those parameters, as automatically as possible and using a syntax as close as possible to the way you would refer to those variables in the model’s language. An approach similar to ggmcmc’s is also provided by gather_draws, since sometimes you do want variable names as values in a column. The goal is for tidybayes to do the tedious work of figuring out how to make a data frame look the way you need it to, including turning parameters with indices like "b[1,2]" and the like into tidy data for you.
Fit into the tidyverse. tidybayes methods fit into a workflow familiar to users of the tidyverse (dplyr, tidyr, ggplot2, etc.). This means fitting into the pipe (%>%) workflow; using and respecting grouped data frames (thus spread_draws and gather_draws return results already grouped by variable indices, and methods like median_qi calculate point summaries and intervals for variables and groups simultaneously); and not reinventing too much of the wheel if it is already made easy by functions provided by existing tidyverse packages (unless it makes for much clearer code for a common idiom). For compatibility with other package column names (such as broom::tidy), tidybayes provides transformation functions like to_broom_names that can be dropped directly into data transformation pipelines.
Focus on composable operations and plotting primitives, not monolithic plots and operations. Several other packages (notably bayesplot and ggmcmc) already provide an excellent variety of pre-made methods for plotting Bayesian results. tidybayes shies away from duplicating this functionality. Instead, it focuses on providing composable operations for generating and manipulating Bayesian samples in a tidy data format, and graphical primitives for ggplot that allow you to build custom plots easily. Most simply, where bayesplot and ggmcmc tend to have functions with many options that return a full ggplot object, tidybayes tends towards providing primitives (like geoms) that you can compose and combine into your own custom plots. I believe both approaches have their place: pre-made functions are especially useful for common, quick operations that don’t need customization (like many diagnostic plots), while composable operations tend to be more useful for complex custom plots.
Sensible defaults make life easy. But options (and the data being tidy in the first place) make it easy to go your own way when you need to.
Variable names in models should be descriptive, not cryptic. This principle implies avoiding cryptic (and short) subscripts in favor of longer (but descriptive) ones. This is a matter of readability and accessibility of models to others. For example, a common pattern among Stan users (and in the Stan manual) is to use variables like J to refer to the number of elements in a group (e.g., number of participants) and a corresponding index like j to refer to specific elements in that group. I believe this sacrifices too much readability for the sake of concision; I prefer a pattern like n_participant for the size of the group and participant (or a mnemonic short form like p) for specific elements. In functions where names are auto-generated (like compose_data), tidybayes will (by default) assume you want these sorts of more descriptive names; however, you can always override the default naming scheme.
tidybayes aims to support a variety of models with a uniform interface. Currently supported models include rstan, cmdstanr, brms, rstanarm, runjags, rjags, jagsUI, coda::mcmc and coda::mcmc.list, posterior::draws, MCMCglmm, and anything with its own as.mcmc.list implementation. If you install the tidybayes.rethinking package, models from the rethinking package are also supported. For an up-to-date list of supported models, see ?"tidybayes-models".
The following libraries are required to run this vignette:
```r
library(magrittr)
library(dplyr)
library(forcats)
library(modelr)
library(ggdist)
library(tidybayes)
library(ggplot2)
library(cowplot)
library(broom)
library(rstan)
library(rstanarm)
library(brms)
library(bayesplot)
library(RColorBrewer)

theme_set(theme_tidybayes() + panel_border())
```
These options help Stan run faster:
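A typical choice, assumed here rather than taken from the original setup, is to run chains in parallel and to cache compiled models so that unchanged Stan programs are not recompiled. The sketch below uses only the base parallel package and rstan's own rstan_options(); the rstan call is guarded so the snippet still runs where rstan is not installed.

```r
# Assumed setup: run chains in parallel, one per available core
# (detectCores() can return NA on some platforms, hence the fallback to 1)
options(mc.cores = max(1L, parallel::detectCores(), na.rm = TRUE))

# Assumed setup: save compiled Stan models to disk so unchanged programs
# are not recompiled; guarded so this runs even without rstan installed
if (requireNamespace("rstan", quietly = TRUE)) {
  rstan::rstan_options(auto_write = TRUE)
}
```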
To demonstrate tidybayes, we will use a simple dataset with 10 observations from each of 5 conditions:
```r
set.seed(5)
n = 10
n_condition = 5
ABC =
  tibble(
    condition = factor(rep(c("A","B","C","D","E"), n)),
    response = rnorm(n * 5, c(0,1,2,1,-1), 0.5)
  )
```
A snapshot of the data looks like this:
condition | response |
---|---|
A | -0.4204277 |
B | 1.6921797 |
C | 1.3722541 |
D | 1.0350714 |
E | -0.1442796 |
A | -0.3014540 |
B | 0.7639168 |
C | 1.6823143 |
D | 0.8571132 |
E | -0.9309459 |
This is a typical tidy format data frame: one observation per row. Graphically:
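A minimal sketch of such a plot, assuming ggplot2: one point per observation, with conditions on the y axis. The example data is recreated exactly as above so that the snippet runs on its own.

```r
library(dplyr)    # dplyr re-exports tibble() and %>%
library(ggplot2)

# Recreate the example data exactly as above
set.seed(5)
n = 10
n_condition = 5
ABC = tibble(
  condition = factor(rep(c("A", "B", "C", "D", "E"), n)),
  response = rnorm(n * n_condition, c(0, 1, 2, 1, -1), 0.5)
)

# One point per observation (row), one row of points per condition
p = ABC %>%
  ggplot(aes(y = condition, x = response)) +
  geom_point()
p
```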
compose_data to prepare a data frame for the model
Shunting data from a data frame into a format usable in samplers like JAGS or Stan can involve a tedious set of operations, like generating index variables storing the number of observations or the number of levels in a factor. compose_data automates these operations.
A hierarchical model of our example data might fit an overall mean across the conditions (overall_mean), the standard deviation of the condition means (condition_mean_sd), the mean within each condition (condition_mean[condition]) and the standard deviation of the responses given a condition mean (response_sd):
```stan
data {
  int<lower=1> n;
  int<lower=1> n_condition;
  array[n] int<lower=1, upper=n_condition> condition;
  array[n] real response;
}
parameters {
  real overall_mean;
  vector[n_condition] condition_zoffset;
  real<lower=0> response_sd;
  real<lower=0> condition_mean_sd;
}
transformed parameters {
  vector[n_condition] condition_mean;
  condition_mean = overall_mean + condition_zoffset * condition_mean_sd;
}
model {
  response_sd ~ cauchy(0, 1);       // => half-cauchy(0, 1)
  condition_mean_sd ~ cauchy(0, 1); // => half-cauchy(0, 1)
  overall_mean ~ normal(0, 5);
  condition_zoffset ~ normal(0, 1); // => condition_mean ~ normal(overall_mean, condition_mean_sd)
  for (i in 1:n) {
    response[i] ~ normal(condition_mean[condition[i]], response_sd);
  }
}
```
We have compiled and loaded this model into the variable ABC_stan.
This model expects these variables as input:

- n: number of observations
- n_condition: number of conditions
- condition: a vector of integers indicating the condition of each observation
- response: a vector of observations

Our data frame (ABC) only has response and condition, and condition is in the wrong format (it is a factor instead of numeric). However, compose_data can generate a list containing the above variables in the correct format automatically. It recognizes that condition is a factor and converts it to a numeric, adds the n_condition variable automatically containing the number of levels in condition, and adds the n column containing the number of observations (number of rows in the data frame):
```
## $condition
## [1] 1 2 3 4 5 1 2 3 4 5 1 2 3 4 5 1 2 3 4 5 1 2 3 4 5 1 2 3 4 5 1 2 3 4 5 1 2 3 4 5 1 2 3 4 5 1 2 3 4 5
##
## $n_condition
## [1] 5
##
## $response
## [1] -0.42042774 1.69217967 1.37225407 1.03507138 -0.14427956 -0.30145399 0.76391681 1.68231434 0.85711318
## [10] -0.93094589 0.61381517 0.59911027 1.45980370 0.92123282 -1.53588002 -0.06949307 0.70134345 0.90801662
## [19] 1.12040863 -1.12967770 0.45025597 1.47093470 2.73398095 1.35338054 -0.59049553 -0.14674092 1.70929454
## [28] 2.74938691 0.67145895 -1.42639772 0.15795752 1.55484708 3.10773029 1.60855182 -0.26038911 0.47578692
## [37] 0.49523368 0.99976363 0.11890706 -1.07130406 0.77503018 0.59878841 1.96271054 1.94783398 -1.22828447
## [46] 0.28111168 0.55649574 1.76987771 0.63783576 -1.03460558
##
## $n
## [1] 50
```
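For this particular data frame, the conversion amounts to the following base R sketch. It illustrates what compose_data() produces here and is not its implementation; the real function handles more column types and lets you override the naming scheme.

```r
# Recreate the example data (same seed and draws as above)
set.seed(5)
ABC = data.frame(
  condition = factor(rep(c("A", "B", "C", "D", "E"), 10)),
  response = rnorm(50, c(0, 1, 2, 1, -1), 0.5)
)

# Hand-rolled equivalent of compose_data(ABC) for this data frame only
stan_data = list(
  condition = as.integer(ABC$condition),  # factor -> 1-based integer level codes
  n_condition = nlevels(ABC$condition),   # "n_" + column name: number of levels
  response = ABC$response,                # numeric columns pass through unchanged
  n = nrow(ABC)                           # number of observations (rows)
)
str(stan_data)
```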
This makes it easy to skip right to running the model without munging the data yourself:
The results look like this:
```
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## overall_mean 0.63 0.02 0.62 -0.55 0.28 0.63 0.95 1.87 915 1
## condition_zoffset[1] -0.39 0.02 0.49 -1.41 -0.71 -0.39 -0.05 0.58 973 1
## condition_zoffset[2] 0.35 0.02 0.49 -0.62 0.04 0.33 0.66 1.31 902 1
## condition_zoffset[3] 1.11 0.02 0.59 -0.02 0.70 1.11 1.50 2.30 868 1
## condition_zoffset[4] 0.37 0.02 0.49 -0.62 0.04 0.36 0.69 1.31 905 1
## condition_zoffset[5] -1.39 0.02 0.65 -2.74 -1.81 -1.36 -0.94 -0.23 1031 1
## response_sd 0.56 0.00 0.06 0.46 0.52 0.56 0.60 0.70 1793 1
## condition_mean_sd 1.23 0.02 0.51 0.62 0.90 1.12 1.43 2.50 997 1
## condition_mean[1] 0.20 0.00 0.17 -0.14 0.09 0.20 0.32 0.55 4885 1
## condition_mean[2] 1.00 0.00 0.18 0.65 0.89 1.01 1.12 1.34 4606 1
## condition_mean[3] 1.84 0.00 0.18 1.48 1.72 1.84 1.95 2.19 4893 1
## condition_mean[4] 1.02 0.00 0.18 0.68 0.90 1.02 1.14 1.37 4274 1
## condition_mean[5] -0.89 0.00 0.18 -1.23 -1.01 -0.89 -0.77 -0.53 4725 1
## lp__ 0.30 0.08 2.34 -5.02 -1.08 0.63 2.06 3.79 915 1
##
## Samples were drawn using NUTS(diag_e) at Sun Sep 15 00:33:21 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
```
spread_draws
Now that we have our results, the fun begins: getting the draws out in a tidy format! The default methods in Stan for extracting draws from the model do so in a nested format:
```
## List of 6
## $ overall_mean : num [1:4000(1d)] 1.113 0.818 0.432 0.794 0.528 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ condition_zoffset: num [1:4000, 1:5] -0.9652 -0.4328 -0.0276 -0.4277 -0.0597 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ iterations: NULL
## .. ..$ : NULL
## $ response_sd : num [1:4000(1d)] 0.696 0.587 0.651 0.571 0.566 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ condition_mean_sd: num [1:4000(1d)] 0.974 1.068 0.53 1.456 0.951 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ condition_mean : num [1:4000, 1:5] 0.173 0.356 0.417 0.171 0.472 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ iterations: NULL
## .. ..$ : NULL
## $ lp__ : num [1:4000(1d)] -1.598 0.887 -5.352 3.171 0.319 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
```
There are also methods for extracting draws as matrices or data frames in Stan (and other model types, such as JAGS and MCMCglmm, have their own formats).
The spread_draws method yields a common format for all model types supported by tidybayes. It lets us instead extract draws into a data frame in tidy format, with a .chain and .iteration column storing the chain and iteration for each row (if available), a .draw column that uniquely indexes each draw, and the remaining columns corresponding to model variables or variable indices. The spread_draws method accepts any number of column specifications, which can include names for variables and names for variable indices. For example, we can extract the condition_mean variable as a tidy data frame, and put the value of its first (and only) index into the condition column, using a syntax that directly echoes how we would specify indices of the condition_mean variable in the model itself:
condition | condition_mean | .chain | .iteration | .draw |
---|---|---|---|---|
1 | 0.0054368 | 1 | 1 | 1 |
1 | -0.0835896 | 1 | 2 | 2 |
1 | 0.0324232 | 1 | 3 | 3 |
1 | 0.1126821 | 1 | 4 | 4 |
1 | 0.1567650 | 1 | 5 | 5 |
1 | 0.2184783 | 1 | 6 | 6 |
1 | 0.2759586 | 1 | 7 | 7 |
1 | 0.0130420 | 1 | 8 | 8 |
1 | 0.1523690 | 1 | 9 | 9 |
1 | 0.1918692 | 1 | 10 | 10 |
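To experiment with this syntax without fitting the model, spread_draws() can be pointed at any supported object. The sketch below wraps simulated draws in a coda::mcmc object (one of the supported types listed earlier). The column names mimic Stan's naming, but the values are random numbers, not output from the model above.

```r
library(coda)
library(tidybayes)

set.seed(1)
# Assumption for illustration: 100 random "draws" of a length-5 vector,
# with columns named the way Stan names them (not real posterior draws)
draw_matrix = matrix(
  rnorm(100 * 5), ncol = 5,
  dimnames = list(NULL, paste0("condition_mean[", 1:5, "]"))
)
fake_fit = mcmc(draw_matrix)

# The bracket syntax names the index column: one row per draw and index value
tidy_means = spread_draws(fake_fit, condition_mean[condition])
head(tidy_means)
```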
As-is, the resulting variables don’t know anything about where their indices came from. The index of the condition_mean variable was originally derived from the condition factor in the ABC data frame. But Stan doesn’t know this: it is just a numeric index to Stan, so the condition column just contains numbers (1, 2, 3, 4, 5) instead of the factor levels these numbers correspond to ("A", "B", "C", "D", "E").
We can recover this missing type information by passing the model through recover_types before using spread_draws. In itself recover_types just returns a copy of the model, with some additional attributes that store the type information from the data frame (or other objects) that you pass to it. This doesn’t have any useful effect by itself, but functions like spread_draws use this information to convert any column or index back into the data type of the column with the same name in the original data frame. In this example, spread_draws recognizes that the condition column was a factor with five levels ("A", "B", "C", "D", "E") in the original data frame, and automatically converts it back into a factor:
condition | condition_mean | .chain | .iteration | .draw |
---|---|---|---|---|
A | 0.0054368 | 1 | 1 | 1 |
A | -0.0835896 | 1 | 2 | 2 |
A | 0.0324232 | 1 | 3 | 3 |
A | 0.1126821 | 1 | 4 | 4 |
A | 0.1567650 | 1 | 5 | 5 |
A | 0.2184783 | 1 | 6 | 6 |
A | 0.2759586 | 1 | 7 | 7 |
A | 0.0130420 | 1 | 8 | 8 |
A | 0.1523690 | 1 | 9 | 9 |
A | 0.1918692 | 1 | 10 | 10 |
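Conceptually, converting an index back into a factor maps each integer onto the corresponding level of the original factor. A base R sketch of that mapping (an illustration of the idea, not how recover_types() and spread_draws() are implemented):

```r
# Levels of the original ABC$condition factor
condition_levels = c("A", "B", "C", "D", "E")

# Numeric indices as Stan reports them
index = c(1L, 1L, 2L, 5L)

# Index i refers to the i-th level; passing `levels =` keeps the full level
# set and its order even when some levels do not appear in `index`
recovered = factor(condition_levels[index], levels = condition_levels)
recovered
```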
Because we often want to make multiple separate calls to spread_draws, it is convenient to decorate the original model using recover_types immediately after it has been fit, so we only have to call it once:
Now we can omit the recover_types call before subsequent calls to spread_draws.
point_interval functions: [median|mean|mode]_[qi|hdi]
tidybayes provides a family of functions for generating point summaries and intervals from draws in a tidy format. These functions follow the naming scheme [median|mean|mode]_[qi|hdi], for example, median_qi, mean_qi, mode_hdi, and so on. The first name (before the _) indicates the type of point summary, and the second name indicates the type of interval. qi yields a quantile interval (a.k.a. equi-tailed interval, central interval, or percentile interval) and hdi yields a highest density interval. Custom point or interval functions can also be applied using the point_interval function.
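As a quick check on what qi means: a 95% quantile interval is just the 2.5% and 97.5% quantiles of the draws. The sketch below uses simulated draws and compares a hand computation with ggdist::median_qi() applied to a bare vector, which returns y, ymin, and ymax columns. Exact agreement assumes ggdist uses R's default quantile definition, which is worth verifying on your installed version.

```r
library(ggdist)

set.seed(2)
# Simulated draws standing in for one model variable (not real model output)
x = rnorm(4000, mean = 0.6, sd = 0.6)

# A 95% quantile (equi-tailed) interval leaves 2.5% of the draws in each tail
by_hand = quantile(x, probs = c(0.025, 0.975), names = FALSE)

# On a bare numeric vector, median_qi() returns a data frame with y/ymin/ymax
summary_qi = median_qi(x, .width = 0.95)
summary_qi
by_hand
```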
For example, we might extract the draws corresponding to the overall mean and standard deviation of observations:
.chain | .iteration | .draw | overall_mean | response_sd |
---|---|---|---|---|
1 | 1 | 1 | 0.0672484 | 0.5755117 |
1 | 2 | 2 | 0.0361190 | 0.5763601 |
1 | 3 | 3 | 1.1679039 | 0.5505738 |
1 | 4 | 4 | 0.3780767 | 0.5763916 |
1 | 5 | 5 | 0.3593583 | 0.5834732 |
1 | 6 | 6 | 0.3319676 | 0.6210914 |
1 | 7 | 7 | 0.2260568 | 0.6407225 |
1 | 8 | 8 | 0.1074734 | 0.6370637 |
1 | 9 | 9 | 0.2249078 | 0.6087050 |
1 | 10 | 10 | -0.0395213 | 0.5211262 |
As with condition_mean[condition], this gives us a tidy data frame. If we want the median and 95% quantile interval of the variables, we can apply median_qi:
overall_mean | overall_mean.lower | overall_mean.upper | response_sd | response_sd.lower | response_sd.upper | .width | .point | .interval |
---|---|---|---|---|---|---|---|---|
0.6331377 | -0.5458905 | 1.868912 | 0.5575944 | 0.455513 | 0.6976528 | 0.95 | median | qi |
median_qi summarizes each input column using its median. If there are multiple columns to summarize, each gets its own x.lower and x.upper column (for each column x) corresponding to the bounds of the .width% interval. If there is only one column, the names .lower and .upper are used for the interval bounds.
We can specify the columns we want to get medians and intervals from, as above, or if we omit the list of columns, median_qi will use every column that is not a grouping column or a special column (like .chain, .iteration, or .draw).
Thus in the above example, overall_mean and response_sd are redundant arguments to median_qi because they are also the only columns we gathered from the model. So we can simplify the previous code to the following:
overall_mean | overall_mean.lower | overall_mean.upper | response_sd | response_sd.lower | response_sd.upper | .width | .point | .interval |
---|---|---|---|---|---|---|---|---|
0.6331377 | -0.5458905 | 1.868912 | 0.5575944 | 0.455513 | 0.6976528 | 0.95 | median | qi |
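The column-naming rules above can be checked without a fitted model. The sketch below builds a made-up data frame shaped like the output of spread_draws(overall_mean, response_sd) and compares explicit, implicit, and single-column calls to median_qi().

```r
library(dplyr)
library(ggdist)

set.seed(3)
# Made-up stand-in for m %>% spread_draws(overall_mean, response_sd);
# only the shape matters here, the values are not model output
draws = tibble(
  .chain = 1L,
  .iteration = 1:4000,
  .draw = 1:4000,
  overall_mean = rnorm(4000, 0.6, 0.6),
  response_sd = rgamma(4000, shape = 90, rate = 160)
)

# Several summarised columns: each gets its own <name>.lower / <name>.upper pair
explicit = draws %>% median_qi(overall_mean, response_sd)

# No columns named: every non-grouping, non-special column is summarised,
# so .chain, .iteration and .draw are skipped automatically
implicit = draws %>% median_qi()

# A single summarised column: plain .lower / .upper
single = draws %>% median_qi(overall_mean)
```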
When we have a variable with one or more indices, such as condition_mean, we can apply median_qi (or other functions in the point_interval family) as we did before:
condition | condition_mean | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A | 0.1988803 | -0.1422924 | 0.5485277 | 0.95 | median | qi |
B | 1.0064882 | 0.6510522 | 1.3407107 | 0.95 | median | qi |
C | 1.8410297 | 1.4791098 | 2.1868748 | 0.95 | median | qi |
D | 1.0221440 | 0.6812030 | 1.3665086 | 0.95 | median | qi |
E | -0.8897187 | -1.2326187 | -0.5288822 | 0.95 | median | qi |
How did median_qi know what to aggregate? Data frames returned by spread_draws are automatically grouped by all index variables you pass to it; in this case, that means it groups by condition. median_qi respects groups, and calculates the point summaries and intervals within all groups. Then, because no columns were passed to median_qi, it acts on the only non-special (.-prefixed) and non-group column, condition_mean. So the above shortened syntax is equivalent to this more verbose call:
```r
m %>%
  spread_draws(condition_mean[condition]) %>%
  group_by(condition) %>%   # this line not necessary (done automatically by spread_draws)
  median_qi(condition_mean)
```
condition | condition_mean | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A | 0.1988803 | -0.1422924 | 0.5485277 | 0.95 | median | qi |
B | 1.0064882 | 0.6510522 | 1.3407107 | 0.95 | median | qi |
C | 1.8410297 | 1.4791098 | 2.1868748 | 0.95 | median | qi |
D | 1.0221440 | 0.6812030 | 1.3665086 | 0.95 | median | qi |
E | -0.8897187 | -1.2326187 | -0.5288822 | 0.95 | median | qi |
When given only a single column, median_qi will use the names .lower and .upper for the lower and upper ends of the intervals.
tidybayes also provides an implementation of posterior::summarise_draws() for grouped data frames (tidybayes::summarise_draws.grouped_df()), which you can use to quickly get convergence diagnostics:
condition | variable | mean | median | sd | mad | q5 | q95 | rhat | ess_bulk | ess_tail |
---|---|---|---|---|---|---|---|---|---|---|
A | condition_mean | 0.2012402 | 0.1988803 | 0.1721495 | 0.1705762 | -0.0846900 | 0.4813889 | 1.001423 | 4779.163 | 3171.554 |
B | condition_mean | 1.0026959 | 1.0064882 | 0.1761769 | 0.1728642 | 0.7107108 | 1.2820906 | 0.999899 | 4659.685 | 3491.971 |
C | condition_mean | 1.8363327 | 1.8410297 | 0.1781378 | 0.1744494 | 1.5383924 | 2.1243495 | 1.000396 | 4794.242 | 3600.513 |
D | condition_mean | 1.0221078 | 1.0221440 | 0.1766632 | 0.1777818 | 0.7380914 | 1.3121048 | 1.000279 | 4302.768 | 3059.668 |
E | condition_mean | -0.8876552 | -0.8897187 | 0.1796980 | 0.1772229 | -1.1766026 | -0.5933931 | 1.002886 | 4757.787 | 3247.647 |
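A sketch of the same call on a made-up grouped data frame with four chains per condition, shaped like the output of spread_draws(condition_mean[condition]). Loading tidybayes registers the grouped_df method, so posterior's summarise_draws() returns one row of summaries and diagnostics per group; because the values are independent normal draws, the diagnostics say nothing about the real model.

```r
library(dplyr)
library(posterior)
library(tidybayes)

set.seed(7)
# Made-up stand-in for m %>% spread_draws(condition_mean[condition]):
# 4 chains x 250 iterations for each of five conditions
draws = tibble(
  .chain = rep(rep(1:4, each = 250), times = 5),
  .iteration = rep(rep(1:250, times = 4), times = 5),
  .draw = rep(1:1000, times = 5),
  condition = factor(rep(c("A", "B", "C", "D", "E"), each = 1000)),
  condition_mean = rnorm(5000, rep(c(0.2, 1, 1.8, 1, -0.9), each = 1000), 0.18)
) %>%
  group_by(condition)

# Dispatches to tidybayes' grouped_df method: summaries plus rhat,
# ess_bulk and ess_tail for each condition
diagnostics = summarise_draws(draws)
diagnostics
```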
geom_pointinterval
Plotting medians and intervals is straightforward using ggdist::geom_pointinterval() or ggdist::stat_pointinterval(), which are similar to ggplot2::geom_pointrange() but with sensible defaults for multiple intervals. For example:
```r
m %>%
  spread_draws(condition_mean[condition]) %>%
  ggplot(aes(y = fct_rev(condition), x = condition_mean)) +
  stat_pointinterval()
```
These functions have .width = c(.66, .95) by default (showing 66% and 95% intervals), but this can be changed by passing a .width argument to ggdist::stat_pointinterval().
stat_eye
The ggdist::stat_halfeye() stat provides a shortcut to generating “half-eye plots” (combinations of intervals and densities). This example also demonstrates how to change the interval probability (here, to 90% and 50% intervals):
```r
m %>%
  spread_draws(condition_mean[condition]) %>%
  ggplot(aes(y = fct_rev(condition), x = condition_mean)) +
  stat_halfeye(.width = c(.90, .5))
```
Or say you want to annotate portions of the densities in color; the fill aesthetic can vary within a slab in all geoms and stats in the ggdist::geom_slabinterval() family, including ggdist::stat_halfeye(). For example, if you want to annotate a domain-specific region of practical equivalence (ROPE), you could do something like this:
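A sketch of such a plot on simulated draws, assuming a hypothetical ROPE of ±0.8 (the bounds are illustrative only; in practice they come from domain knowledge). The fill is computed for each position within a slab via after_stat(x), so the part of each density inside the ROPE is colored differently.

```r
library(dplyr)
library(ggplot2)
library(ggdist)

set.seed(9)
# Made-up stand-in for m %>% spread_draws(condition_mean[condition])
draws = tibble(
  condition = factor(rep(c("A", "B", "C", "D", "E"), each = 1000)),
  condition_mean = rnorm(5000, rep(c(0.2, 1, 1.8, 1, -0.9), each = 1000), 0.18)
)

rope = 0.8   # hypothetical ROPE half-width, chosen only for illustration

p = draws %>%
  ggplot(aes(
    y = condition, x = condition_mean,
    # after_stat(x) is the x position within each computed slab
    fill = after_stat(abs(x) < rope)
  )) +
  stat_halfeye() +
  geom_vline(xintercept = c(-rope, rope), linetype = "dashed") +
  # FALSE (outside the ROPE) -> gray, TRUE (inside) -> blue
  scale_fill_manual(values = c("gray80", "skyblue"))
p
```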
stat_slabinterval
There are a variety of additional stats for visualizing distributions in the ggdist::geom_slabinterval() family of stats and geoms:
See vignette("slabinterval", package = "ggdist") for an overview.
.width = argument
If you wish to summarise the data before plotting (sometimes useful for large samples), median_qi() and its sister functions can also produce an arbitrary number of probability intervals by setting the .width = argument:
condition | condition_mean | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A | 0.1988803 | -0.1422924 | 0.5485277 | 0.95 | median | qi |
B | 1.0064882 | 0.6510522 | 1.3407107 | 0.95 | median | qi |
C | 1.8410297 | 1.4791098 | 2.1868748 | 0.95 | median | qi |
D | 1.0221440 | 0.6812030 | 1.3665086 | 0.95 | median | qi |
E | -0.8897187 | -1.2326187 | -0.5288822 | 0.95 | median | qi |
A | 0.1988803 | -0.0145172 | 0.4205475 | 0.80 | median | qi |
B | 1.0064882 | 0.7766536 | 1.2217134 | 0.80 | median | qi |
C | 1.8410297 | 1.6051066 | 2.0628260 | 0.80 | median | qi |
D | 1.0221440 | 0.7967317 | 1.2450835 | 0.80 | median | qi |
E | -0.8897187 | -1.1144668 | -0.6581584 | 0.80 | median | qi |
A | 0.1988803 | 0.0869769 | 0.3168518 | 0.50 | median | qi |
B | 1.0064882 | 0.8873874 | 1.1205403 | 0.50 | median | qi |
C | 1.8410297 | 1.7192788 | 1.9549433 | 0.50 | median | qi |
D | 1.0221440 | 0.9039277 | 1.1426405 | 0.50 | median | qi |
E | -0.8897187 | -1.0053718 | -0.7654213 | 0.50 | median | qi |
The results are in a tidy format: one row per index (condition) and probability level (.width). This facilitates plotting. For example, assigning -.width to the linewidth aesthetic will show all intervals, making thicker lines correspond to smaller intervals:
```r
m %>%
  spread_draws(condition_mean[condition]) %>%
  median_qi(.width = c(.95, .66)) %>%
  ggplot(aes(
    y = fct_rev(condition), x = condition_mean, xmin = .lower, xmax = .upper,
    # linewidth = -.width means a smaller probability interval gets a thicker line
    # this can be omitted, geom_pointinterval includes it automatically
    # if a .width column is in the input data.
    linewidth = -.width
  )) +
  geom_pointinterval()
```
ggdist::geom_pointinterval() includes linewidth = -.width as a default aesthetic mapping to facilitate exactly this usage.
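The shape of this multi-width output does not depend on the model, so it can be checked with made-up draws grouped the way spread_draws() groups them: three widths for five conditions give 15 rows.

```r
library(dplyr)
library(ggdist)

set.seed(10)
# Made-up stand-in for m %>% spread_draws(condition_mean[condition]),
# which comes back grouped by condition
draws = tibble(
  .draw = rep(1:1000, times = 5),
  condition = factor(rep(c("A", "B", "C", "D", "E"), each = 1000)),
  condition_mean = rnorm(5000, rep(c(0.2, 1, 1.8, 1, -0.9), each = 1000), 0.18)
) %>%
  group_by(condition)

# One row per condition and .width: 5 groups x 3 widths = 15 rows
intervals = draws %>% median_qi(condition_mean, .width = c(.95, .8, .5))
intervals
```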
Intervals are nice if the alpha level happens to line up with whatever decision you are trying to make, but seeing the shape of the posterior is better (hence eye plots, above). On the other hand, making inferences from density plots is imprecise (estimating the area of one shape as a proportion of another is a hard perceptual task). Reasoning about probability in frequency formats is easier, motivating quantile dotplots (Kay et al. 2016, Fernandes et al. 2018), which also allow precise estimation of arbitrary intervals (down to the dot resolution of the plot, 100 in the example below).
Within the slabinterval family of geoms in ggdist are the dots and dotsinterval geoms, which automatically determine appropriate bin sizes for dotplots and can calculate quantiles from samples to construct quantile dotplots. ggdist::stat_dots() and ggdist::stat_dotsinterval() (used below) are the variants designed for use on samples:
```r
m %>%
  spread_draws(condition_mean[condition]) %>%
  ggplot(aes(x = condition_mean, y = fct_rev(condition))) +
  stat_dotsinterval(quantiles = 100)
```
The idea is to get away from thinking about the posterior as indicating one canonical point or interval, but instead to represent it as (say) 100 approximately equally likely points.
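The quantiles behind a 100-dot quantile dotplot can be sketched in base R: take evenly spaced probabilities and read the corresponding quantiles off the draws. This illustrates the idea only; ggdist's own choice of probabilities and quantile type may differ.

```r
set.seed(11)
# Simulated draws standing in for one condition mean (not real model output)
x = rnorm(4000, mean = 1, sd = 0.18)

# ppoints(100) gives the probabilities (i - 0.5) / 100 for i = 1, ..., 100,
# so each resulting dot stands for 1% of the posterior mass
probs = ppoints(100)
dots = quantile(x, probs = probs, names = FALSE)

# Reading a probability off the dots: the share of dots above 1.2
mean(dots > 1.2)
```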
The point_interval() family of functions follows the naming scheme [median|mean|mode]_[qi|hdi|hdci], and all work in the same way as median_qi(): they take a series of names (or expressions calculated on columns) and summarize those columns with the corresponding point summary function (median, mean, or mode) and interval (qi, hdi, or hdci). qi yields a quantile interval (a.k.a. equi-tailed interval, central interval, or percentile interval), hdi yields one or more highest (posterior) density interval(s), and hdci yields a single (possibly) highest-density continuous interval. These can be used in any combination desired.
The *_hdi functions have an additional difference: In the case of multimodal distributions, they may return multiple intervals for each probability level. Here are some draws from a multimodal normal mixture:
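A sketch of one way to generate such a mixture (the component means, weights, and seed are assumptions), together with the two kinds of highest-density summaries: mode_hdi() can split into one interval per mode, while mode_hdci() always returns a single continuous interval.

```r
library(dplyr)
library(ggdist)

set.seed(123)
# Assumed two-component normal mixture: 5000 draws near 0 and 2500 near 4
multimodal_draws = tibble(
  x = c(rnorm(5000, 0, 1), rnorm(2500, 4, 1))
)

# Highest-density intervals: here one interval around each mode
hdi_80 = multimodal_draws %>% mode_hdi(x, .width = .80)

# Highest-density continuous interval: always a single interval
hdci_80 = multimodal_draws %>% mode_hdci(x, .width = .80)
hdi_80
hdci_80
```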
Passed through mode_hdi(), we get multiple intervals at the 80% probability level:
x | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|
-0.0605292 | -1.455671 | 1.540140 | 0.8 | mode | hdi |
-0.0605292 | 3.106254 | 5.005503 | 0.8 | mode | hdi |
This is easier to see when plotted:
```r
multimodal_draws %>%
  ggplot(aes(x = x)) +
  stat_slab(aes(y = 0)) +
  stat_pointinterval(aes(y = -0.5), point_interval = median_qi, .width = c(.95, .80)) +
  annotate("text", label = "median, 80% and 95% quantile intervals", x = 6, y = -0.5, hjust = 0, vjust = 0.3) +
  stat_pointinterval(aes(y = -0.25), point_interval = mode_hdi, .width = c(.95, .80)) +
  annotate("text", label = "mode, 80% and 95% highest-density intervals", x = 6, y = -0.25, hjust = 0, vjust = 0.3) +
  xlim(-3.25, 18) +
  scale_y_continuous(breaks = NULL)
```
spread_draws() supports extracting variables that have different indices. It automatically matches up indices with the same name, and duplicates values as necessary to produce one row for every combination of levels of all indices. For example, we might want to calculate the difference between each condition mean and the overall mean. To do that, we can extract draws from the overall mean and all condition means:
.chain | .iteration | .draw | overall_mean | condition | condition_mean |
---|---|---|---|---|---|
1 | 1 | 1 | 0.0672484 | A | 0.0054368 |
1 | 1 | 1 | 0.0672484 | B | 1.0288799 |
1 | 1 | 1 | 0.0672484 | C | 1.8429073 |
1 | 1 | 1 | 0.0672484 | D | 1.2524837 |
1 | 1 | 1 | 0.0672484 | E | -0.7239581 |
1 | 2 | 2 | 0.0361190 | A | -0.0835896 |
1 | 2 | 2 | 0.0361190 | B | 0.8732236 |
1 | 2 | 2 | 0.0361190 | C | 1.7790774 |
1 | 2 | 2 | 0.0361190 | D | 1.1211601 |
1 | 2 | 2 | 0.0361190 | E | -0.8842005 |
Within each draw, overall_mean is repeated as necessary to correspond to every index of condition_mean. Thus, the dplyr::mutate() function can be used to take the differences over all rows, then we can summarize with median_qi():
```r
m %>%
  spread_draws(overall_mean, condition_mean[condition]) %>%
  mutate(condition_offset = condition_mean - overall_mean) %>%
  median_qi(condition_offset)
```
condition | condition_offset | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A | -0.4202269 | -1.6850299 | 0.7540214 | 0.95 | median | qi |
B | 0.3627166 | -0.8630461 | 1.5860078 | 0.95 | median | qi |
C | 1.2041326 | -0.0333102 | 2.4400293 | 0.95 | median | qi |
D | 0.3972638 | -0.8850370 | 1.6075581 | 0.95 | median | qi |
E | -1.5037823 | -2.8517632 | -0.3403648 | 0.95 | median | qi |
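Conceptually, the row layout that spread_draws(overall_mean, condition_mean[condition]) produces resembles joining per-draw values of overall_mean onto per-draw, per-condition values of condition_mean by .draw. The sketch below reproduces that layout with dplyr on made-up draws (an analogy for illustration, not how spread_draws() is implemented) and then computes the same offsets.

```r
library(dplyr)
library(ggdist)

set.seed(4)
n_draws = 1000

# Made-up stand-ins (not model output): one overall_mean per draw ...
overall = tibble(.draw = 1:n_draws, overall_mean = rnorm(n_draws, 0.6, 0.6))

# ... and one condition_mean per draw and condition
by_condition = tibble(
  .draw = rep(1:n_draws, each = 5),
  condition = factor(rep(c("A", "B", "C", "D", "E"), times = n_draws)),
  condition_mean = rnorm(5 * n_draws, rep(c(0.2, 1, 1.8, 1, -0.9), times = n_draws), 0.18)
)

# Joining on .draw repeats each overall_mean value for all five conditions
combined = inner_join(overall, by_condition, by = ".draw")

offsets = combined %>%
  mutate(condition_offset = condition_mean - overall_mean) %>%
  group_by(condition) %>%
  median_qi(condition_offset)
offsets
```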
We can use combinations of variables with different indices to generate predictions from the model. In this case, we can combine the condition means with the residual standard deviation to generate predictive distributions from the model, then show the distributions using ggdist::geom_interval() and compare them to the data:
```r
m %>%
  spread_draws(condition_mean[condition], response_sd) %>%
  mutate(y_rep = rnorm(n(), condition_mean, response_sd)) %>%
  median_qi(y_rep, .width = c(.95, .8, .5)) %>%
  ggplot(aes(y = fct_rev(condition), x = y_rep)) +
  geom_interval(aes(xmin = .lower, xmax = .upper)) + # auto-sets aes(color = fct_rev(ordered(.width)))
  geom_point(aes(x = response), data = ABC) +
  scale_color_brewer()
```
If this model is well-calibrated, about 95% of the data should be within the outer intervals, 80% in the next-smallest intervals, and 50% in the smallest intervals.
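One way to check that statement is to compute, for each .width, the share of observations that fall inside their condition's predictive interval. A sketch on made-up posterior draws (so the resulting coverages say nothing about the actual model); the relationship = argument assumes dplyr 1.1.0 or later.

```r
library(dplyr)
library(ggdist)

# The example data, recreated as above
set.seed(5)
ABC = tibble(
  condition = factor(rep(c("A", "B", "C", "D", "E"), 10)),
  response = rnorm(50, c(0, 1, 2, 1, -1), 0.5)
)

# Made-up stand-in for m %>% spread_draws(condition_mean[condition], response_sd)
set.seed(6)
n_draws = 4000
draws = tibble(
  .draw = rep(1:n_draws, times = 5),
  condition = factor(rep(c("A", "B", "C", "D", "E"), each = n_draws)),
  condition_mean = rnorm(5 * n_draws, rep(c(0, 1, 2, 1, -1), each = n_draws), 0.18),
  response_sd = rep(rgamma(n_draws, shape = 90, rate = 160), times = 5)
)

# Posterior predictive intervals per condition, as in the plot above
intervals = draws %>%
  mutate(y_rep = rnorm(n(), condition_mean, response_sd)) %>%
  group_by(condition) %>%
  median_qi(y_rep, .width = c(.95, .8, .5))

# Compare each observation with all three intervals of its condition
coverage = ABC %>%
  inner_join(intervals, by = "condition", relationship = "many-to-many") %>%
  group_by(.width) %>%
  summarise(coverage = mean(response >= .lower & response <= .upper))
coverage
```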
Altogether, data, posterior predictions, and posterior distributions of the means:
```r
draws = m %>%
  spread_draws(condition_mean[condition], response_sd)

reps = draws %>%
  mutate(y_rep = rnorm(n(), condition_mean, response_sd))

ABC %>%
  ggplot(aes(y = condition)) +
  stat_interval(aes(x = y_rep), .width = c(.95, .8, .5), data = reps) +
  stat_pointinterval(aes(x = condition_mean), .width = c(.95, .66), position = position_nudge(y = -0.3), data = draws) +
  geom_point(aes(x = response)) +
  scale_color_brewer()
```
compare_levels() allows us to compare the value of some variable across levels of some factor. By default it computes all pairwise differences, though this can be changed using the comparison = argument:
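A sketch on made-up draws shaped like spread_draws(condition_mean[condition]) output. For an unordered five-level factor, the default gives all 10 pairwise differences, labelled like "B - A" in the condition column; comparison = "control" instead compares every level against the first one.

```r
library(dplyr)
library(tidybayes)

set.seed(15)
# Made-up stand-in for m %>% spread_draws(condition_mean[condition])
draws = tibble(
  .chain = 1L,
  .iteration = rep(1:1000, times = 5),
  .draw = rep(1:1000, times = 5),
  condition = factor(rep(c("A", "B", "C", "D", "E"), each = 1000)),
  condition_mean = rnorm(5000, rep(c(0.2, 1, 1.8, 1, -0.9), each = 1000), 0.18)
)

# Default for an unordered factor: all pairwise differences (10 for 5 levels)
pairwise = draws %>% compare_levels(condition_mean, by = condition)

# Each level compared against the first level ("A"): 4 comparisons
vs_control = draws %>%
  compare_levels(condition_mean, by = condition, comparison = "control")
```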
gather_draws and gather_variables
We might also prefer all model variable names to be in a single column (long-format) instead of as column names. There are two methods for obtaining long-format data frames with tidybayes, whose use depends on where and how in the data processing chain you might want to transform into long-format: gather_draws() and gather_variables(). There are also two methods for wide (or semi-wide) format data frames: spread_draws() (described previously) and tidy_draws().
gather_draws() is the counterpart to spread_draws(), except it puts all variable names in a .variable column and all draws in a .value column:
.variable | condition | .value | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|---|
condition_mean | A | 0.1988803 | -0.1422924 | 0.5485277 | 0.95 | median | qi |
condition_mean | B | 1.0064882 | 0.6510522 | 1.3407107 | 0.95 | median | qi |
condition_mean | C | 1.8410297 | 1.4791098 | 2.1868748 | 0.95 | median | qi |
condition_mean | D | 1.0221440 | 0.6812030 | 1.3665086 | 0.95 | median | qi |
condition_mean | E | -0.8897187 | -1.2326187 | -0.5288822 | 0.95 | median | qi |
overall_mean | NA | 0.6331377 | -0.5458905 | 1.8689116 | 0.95 | median | qi |
Note that condition = NA for the overall_mean row, because it does not have an index with that name in the specification passed to gather_draws().
While this works well if we do not need to perform computations that involve multiple columns, the semi-wide format returned by spread_draws() is very useful for computations that involve multiple columns, such as the calculation of the condition_offset above. If we want to make intermediate computations on the format returned by spread_draws and then gather variables into one column, we can use gather_variables(), which will gather all non-grouped variables that are not special columns (like .chain, .iteration, or .draw):
m %>%
spread_draws(overall_mean, condition_mean[condition]) %>%
mutate(condition_offset = condition_mean - overall_mean) %>%
gather_variables() %>%
median_qi()
condition | .variable | .value | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|---|
A | condition_mean | 0.1988803 | -0.1422924 | 0.5485277 | 0.95 | median | qi |
A | condition_offset | -0.4202269 | -1.6850299 | 0.7540214 | 0.95 | median | qi |
A | overall_mean | 0.6331377 | -0.5458905 | 1.8689116 | 0.95 | median | qi |
B | condition_mean | 1.0064882 | 0.6510522 | 1.3407107 | 0.95 | median | qi |
B | condition_offset | 0.3627166 | -0.8630461 | 1.5860078 | 0.95 | median | qi |
B | overall_mean | 0.6331377 | -0.5458905 | 1.8689116 | 0.95 | median | qi |
C | condition_mean | 1.8410297 | 1.4791098 | 2.1868748 | 0.95 | median | qi |
C | condition_offset | 1.2041326 | -0.0333102 | 2.4400293 | 0.95 | median | qi |
C | overall_mean | 0.6331377 | -0.5458905 | 1.8689116 | 0.95 | median | qi |
D | condition_mean | 1.0221440 | 0.6812030 | 1.3665086 | 0.95 | median | qi |
D | condition_offset | 0.3972638 | -0.8850370 | 1.6075581 | 0.95 | median | qi |
D | overall_mean | 0.6331377 | -0.5458905 | 1.8689116 | 0.95 | median | qi |
E | condition_mean | -0.8897187 | -1.2326187 | -0.5288822 | 0.95 | median | qi |
E | condition_offset | -1.5037823 | -2.8517632 | -0.3403648 | 0.95 | median | qi |
E | overall_mean | 0.6331377 | -0.5458905 | 1.8689116 | 0.95 | median | qi |
Note how overall_mean is now repeated here for each condition, because we have performed the gather after spreading model variables across columns.
Finally, if we want raw model variable names as column names instead of having indices split out as their own columns, we can use tidy_draws(). spread_draws() and gather_draws() are generally more useful than tidy_draws(). However, tidy_draws() is provided as a common method for generating data frames from many types of Bayesian models, and gather_draws() and spread_draws() use it internally. The first ten rows of its output look like this:
.chain | .iteration | .draw | overall_mean | condition_zoffset[1] | condition_zoffset[2] | condition_zoffset[3] | condition_zoffset[4] | condition_zoffset[5] | response_sd | condition_mean_sd | condition_mean[1] | condition_mean[2] | condition_mean[3] | condition_mean[4] | condition_mean[5] | lp__ | accept_stat__ | stepsize__ | treedepth__ | n_leapfrog__ | divergent__ | energy__ |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 1 | 1 | 0.0672484 | -0.0237405 | 0.3693427 | 0.6819936 | 0.4552242 | -0.3038859 | 0.5755117 | 2.6036298 | 0.0054368 | 1.0288799 | 1.842907 | 1.2524837 | -0.7239581 | 2.9270158 | 0.9817927 | 0.0426611 | 7 | 143 | 0 | 0.2678035 |
1 | 2 | 2 | 0.0361190 | -0.0467933 | 0.3272184 | 0.6813103 | 0.4241350 | -0.3597464 | 0.5763601 | 2.5582447 | -0.0835896 | 0.8732236 | 1.779077 | 1.1211601 | -0.8842005 | 3.1964737 | 0.8878822 | 0.0426611 | 4 | 23 | 0 | 4.2190459 |
1 | 3 | 3 | 1.1679039 | -0.8274014 | -0.1562730 | 0.4543476 | -0.0012390 | -1.2202145 | 0.5505738 | 1.3723456 | 0.0324232 | 0.9534434 | 1.791426 | 1.1662036 | -0.5066521 | 0.8403493 | 0.9338917 | 0.0426611 | 6 | 127 | 0 | 2.9461908 |
1 | 4 | 4 | 0.3780767 | -0.2600146 | 0.6026545 | 1.6610365 | 0.6463340 | -1.3953297 | 0.5763916 | 1.0206912 | 0.1126821 | 0.9932009 | 2.073482 | 1.0377842 | -1.0461240 | 2.0521706 | 0.9987809 | 0.0426611 | 7 | 191 | 0 | 1.1268984 |
1 | 5 | 5 | 0.3593583 | -0.2067222 | 0.4829257 | 1.7791565 | 0.7158143 | -1.3419146 | 0.5834732 | 0.9800271 | 0.1567650 | 0.8326386 | 2.102980 | 1.0608757 | -0.9557543 | 1.4163438 | 0.9976543 | 0.0426611 | 5 | 31 | 0 | 0.1266447 |
1 | 6 | 6 | 0.3319676 | -0.1316209 | 0.5627070 | 1.8948677 | 0.6717871 | -1.2872584 | 0.6210914 | 0.8622442 | 0.2184783 | 0.8171585 | 1.965806 | 0.9112122 | -0.7779635 | 0.6605171 | 0.9980223 | 0.0426611 | 5 | 47 | 0 | 1.3179615 |
1 | 7 | 7 | 0.2260568 | 0.0641718 | 1.1099979 | 1.5352936 | 1.0336724 | -1.6173631 | 0.6407225 | 0.7776276 | 0.2759586 | 1.0892218 | 1.419944 | 1.0298690 | -1.0316494 | -2.1515097 | 0.9852859 | 0.0426611 | 6 | 127 | 0 | 5.4325581 |
1 | 8 | 8 | 0.1074734 | -0.0822823 | 0.7369365 | 1.3237459 | 0.7938146 | -0.9159997 | 0.6370637 | 1.1476519 | 0.0130420 | 0.9532199 | 1.626673 | 1.0184961 | -0.9437754 | 1.3717614 | 0.9941660 | 0.0426611 | 5 | 63 | 0 | 5.9998791 |
1 | 9 | 9 | 0.2249078 | -0.0611340 | 0.5712596 | 1.4362387 | 0.7550052 | -0.8494661 | 0.6087050 | 1.1865547 | 0.1523690 | 0.9027386 | 1.929084 | 1.1207627 | -0.7830302 | 2.5843960 | 0.9990174 | 0.0426611 | 5 | 63 | 0 | 0.5832453 |
1 | 10 | 10 | -0.0395213 | 0.1997215 | 0.8854809 | 1.6851815 | 0.9193861 | -0.4676779 | 0.5211262 | 1.1585659 | 0.1918692 | 0.9863667 | 1.912872 | 1.0256481 | -0.5813570 | 1.4153766 | 0.9939692 | 0.0426611 | 6 | 79 | 0 | -0.0699659 |
Combining tidy_draws() with gather_variables() also lets us produce output similar to ggmcmc::ggs(), if desired:
.chain | .iteration | .draw | .variable | .value |
---|---|---|---|---|
1 | 1 | 1 | overall_mean | 0.0672484 |
1 | 2 | 2 | overall_mean | 0.0361190 |
1 | 3 | 3 | overall_mean | 1.1679039 |
1 | 4 | 4 | overall_mean | 0.3780767 |
1 | 5 | 5 | overall_mean | 0.3593583 |
1 | 6 | 6 | overall_mean | 0.3319676 |
1 | 7 | 7 | overall_mean | 0.2260568 |
1 | 8 | 8 | overall_mean | 0.1074734 |
1 | 9 | 9 | overall_mean | 0.2249078 |
1 | 10 | 10 | overall_mean | -0.0395213 |
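The same pipeline can be sketched without a fitted model by using posterior's bundled eight-schools draws as a stand-in. This again assumes a tidybayes version that accepts posterior draws objects, and the names wide and long are illustrative. tidy_draws() leaves indices inside raw column names such as theta[1], and gather_variables() then stacks those raw names into .variable:

```r
library(dplyr)
library(tidybayes)

# Stand-in for a fitted model: posterior's bundled eight-schools draws
draws = posterior::example_draws("eight_schools")

# Wide: one column per raw variable name, indices kept inside the names
wide = tidy_draws(draws)

# Long, ggs()-style: .variable holds raw names such as "theta[1]";
# .chain, .iteration, and .draw are excluded from the gather
long = wide %>%
  gather_variables()
```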
But again, this approach does not handle variable indices for us automatically. Using spread_draws() and gather_draws() is generally recommended unless you do not have variable indices to worry about.
You can use regular expressions in the specifications passed to spread_draws() and gather_draws() to match multiple columns by passing regex = TRUE. Our example fit contains variables named condition_mean[i] and condition_zoffset[i]. We could extract both using a single regular expression:
condition | condition_mean | condition_zoffset | .chain | .iteration | .draw |
---|---|---|---|---|---|
A | 0.0054368 | -0.0237405 | 1 | 1 | 1 |
A | -0.0835896 | -0.0467933 | 1 | 2 | 2 |
A | 0.0324232 | -0.8274014 | 1 | 3 | 3 |
A | 0.1126821 | -0.2600146 | 1 | 4 | 4 |
A | 0.1567650 | -0.2067222 | 1 | 5 | 5 |
A | 0.2184783 | -0.1316209 | 1 | 6 | 6 |
A | 0.2759586 | 0.0641718 | 1 | 7 | 7 |
A | 0.0130420 | -0.0822823 | 1 | 8 | 8 |
A | 0.1523690 | -0.0611340 | 1 | 9 | 9 |
A | 0.1918692 | 0.1997215 | 1 | 10 | 10 |
This result is equivalent in this case to spread_draws(c(condition_mean, condition_zoffset)[condition]), but it does not require us to list each variable explicitly. This can be useful, for example, in models with naming schemes like b_[some name] for coefficients.
To demonstrate drawing fit curves with uncertainty, let’s fit a slightly naive model to part of the mtcars dataset using brms::brm():
m_mpg = brm(
mpg ~ hp * cyl,
data = mtcars,
file = "models/tidybayes_m_mpg.rds" # cache model (can be removed)
)
We can draw fit curves with probability bands using add_epred_draws() and ggdist::stat_lineribbon():
mtcars %>%
group_by(cyl) %>%
data_grid(hp = seq_range(hp, n = 51)) %>%
add_epred_draws(m_mpg) %>%
ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
stat_lineribbon(aes(y = .epred)) +
geom_point(data = mtcars) +
scale_fill_brewer(palette = "Greys") +
scale_color_brewer(palette = "Set2")
Or we can sample a reasonable number of fit lines (say 100) and overplot them:
mtcars %>%
group_by(cyl) %>%
data_grid(hp = seq_range(hp, n = 101)) %>%
# NOTE: this shows the use of ndraws to subsample within add_epred_draws()
# ONLY do this IF you are planning to make spaghetti plots, etc.
# NEVER subsample to a small sample to plot intervals, densities, etc.
add_epred_draws(m_mpg, ndraws = 100) %>%
ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
geom_line(aes(y = .epred, group = paste(cyl, .draw)), alpha = .1) +
geom_point(data = mtcars) +
scale_color_brewer(palette = "Dark2")
For more examples of fit line uncertainty, see the corresponding sections in vignette("tidy-brms") or vignette("tidy-rstanarm").
point_interval with broom::tidy: A model comparison example
Combining to_broom_names() with median_qi() (or more generally, the point_interval() family of functions) makes it easy to compare results against models supported by broom::tidy(). For example, let’s compare our model’s estimates of the conditional means against those from an ordinary least squares (OLS) regression:
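The OLS model, m_linear, is not defined in this section. Here is a minimal sketch, assuming it is an lm() fit of response ~ condition to the ABC data. The simulated stand-in for ABC is hypothetical; skip it if you already have ABC. It has five conditions with ten observations each, which is consistent with the 45 residual degrees of freedom reported in the OLS table.

```r
# Hypothetical stand-in for the ABC data used throughout
# (5 conditions x 10 observations, simulated; skip if ABC already exists)
set.seed(5)
ABC = data.frame(
  condition = factor(rep(c("A", "B", "C", "D", "E"), times = 10)),
  response = rnorm(50, mean = c(0, 1, 2, 1, -1), sd = 0.5)
)

# Ordinary least squares model with one mean per condition
m_linear = lm(response ~ condition, data = ABC)
```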
Combining emmeans::emmeans with broom::tidy, we can generate tidy-format summaries of the conditional means from the OLS model, m_linear:
linear_results = m_linear %>%
emmeans::emmeans(~ condition) %>%
tidy(conf.int = TRUE) %>%
mutate(model = "OLS")
linear_results
condition | estimate | std.error | df | conf.low | conf.high | statistic | p.value | model |
---|---|---|---|---|---|---|---|---|
A | 0.1815842 | 0.173236 | 45 | -0.1673310 | 0.5304993 | 1.048190 | 0.3001485 | OLS |
B | 1.0142144 | 0.173236 | 45 | 0.6652993 | 1.3631296 | 5.854526 | 0.0000005 | OLS |
C | 1.8745839 | 0.173236 | 45 | 1.5256687 | 2.2234990 | 10.820985 | 0.0000000 | OLS |
D | 1.0271794 | 0.173236 | 45 | 0.6782642 | 1.3760946 | 5.929366 | 0.0000004 | OLS |
E | -0.9352260 | 0.173236 | 45 | -1.2841411 | -0.5863108 | -5.398567 | 0.0000024 | OLS |
We can derive corresponding fits from our model:
bayes_results = m %>%
spread_draws(condition_mean[condition]) %>%
median_qi(estimate = condition_mean) %>%
to_broom_names() %>%
mutate(model = "Bayes")
bayes_results
condition | estimate | conf.low | conf.high | .width | .point | .interval | model |
---|---|---|---|---|---|---|---|
A | 0.1988803 | -0.1422924 | 0.5485277 | 0.95 | median | qi | Bayes |
B | 1.0064882 | 0.6510522 | 1.3407107 | 0.95 | median | qi | Bayes |
C | 1.8410297 | 1.4791098 | 2.1868748 | 0.95 | median | qi | Bayes |
D | 1.0221440 | 0.6812030 | 1.3665086 | 0.95 | median | qi | Bayes |
E | -0.8897187 | -1.2326187 | -0.5288822 | 0.95 | median | qi | Bayes |
Here, to_broom_names() converts .lower and .upper into conf.low and conf.high, so the names of the columns we need for the comparison (condition, estimate, conf.low, and conf.high) all line up. This makes it simple to combine the two tidy data frames using bind_rows and plot them:
bind_rows(linear_results, bayes_results) %>%
mutate(condition = fct_rev(condition)) %>%
ggplot(aes(y = condition, x = estimate, xmin = conf.low, xmax = conf.high, color = model)) +
geom_pointinterval(position = position_dodge(width = .3))
Observe the shrinkage towards the overall mean in the Bayesian model compared to the OLS model.
bayesplot using unspread_draws and ungather_draws
Functions from other packages might expect draws in the form of a data frame or matrix with variables as columns and draws as rows. That is the format returned by tidy_draws(), but not by gather_draws() or spread_draws(), which split indices out of variable names into separate columns.
You may want to use spread_draws() or gather_draws() to transform your draws in some way, and then convert them back into the draw \(\times\) variable format to pass them to functions from other packages, like bayesplot. The unspread_draws() and ungather_draws() functions invert spread_draws() and gather_draws(). They return a data frame with one row per draw and variable column names that include their indices.
As an example, let’s redo the previous example of compare_levels(), but use bayesplot::mcmc_areas() to plot the results instead of ggdist::stat_eye(). First, the result of compare_levels() looks like this:
m %>%
spread_draws(condition_mean[condition]) %>%
compare_levels(condition_mean, by = condition) %>%
head(10)
.chain | .iteration | .draw | condition | condition_mean |
---|---|---|---|---|
1 | 1 | 1 | B - A | 1.0234431 |
1 | 2 | 2 | B - A | 0.9568132 |
1 | 3 | 3 | B - A | 0.9210202 |
1 | 4 | 4 | B - A | 0.8805187 |
1 | 5 | 5 | B - A | 0.6758735 |
1 | 6 | 6 | B - A | 0.5986802 |
1 | 7 | 7 | B - A | 0.8132632 |
1 | 8 | 8 | B - A | 0.9401779 |
1 | 9 | 9 | B - A | 0.7503696 |
1 | 10 | 10 | B - A | 0.7944975 |
To get a version we can pass to bayesplot::mcmc_areas(), all we need to do is invert the spread_draws() call we started with:
m %>%
spread_draws(condition_mean[condition]) %>%
compare_levels(condition_mean, by = condition) %>%
unspread_draws(condition_mean[condition]) %>%
head(10)
.chain | .iteration | .draw | condition_mean[B - A] | condition_mean[C - A] | condition_mean[C - B] | condition_mean[D - A] | condition_mean[D - B] | condition_mean[D - C] | condition_mean[E - A] | condition_mean[E - B] | condition_mean[E - C] | condition_mean[E - D] |
---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 1 | 1 | 1.0234431 | 1.837471 | 0.8140274 | 1.2470468 | 0.2236037 | -0.5904237 | -0.7293949 | -1.752838 | -2.566865 | -1.976442 |
1 | 2 | 2 | 0.9568132 | 1.862667 | 0.9058538 | 1.2047497 | 0.2479365 | -0.6579174 | -0.8006109 | -1.757424 | -2.663278 | -2.005361 |
1 | 3 | 3 | 0.9210202 | 1.759003 | 0.8379824 | 1.1337804 | 0.2127602 | -0.6252222 | -0.5390753 | -1.460095 | -2.298078 | -1.672856 |
1 | 4 | 4 | 0.8805187 | 1.960800 | 1.0802812 | 0.9251021 | 0.0445834 | -1.0356979 | -1.1588062 | -2.039325 | -3.119606 | -2.083908 |
1 | 5 | 5 | 0.6758735 | 1.946215 | 1.2703413 | 0.9041107 | 0.2282371 | -1.0421042 | -1.1125193 | -1.788393 | -3.058734 | -2.016630 |
1 | 6 | 6 | 0.5986802 | 1.747328 | 1.1486478 | 0.6927339 | 0.0940537 | -1.0545941 | -0.9964417 | -1.595122 | -2.743770 | -1.689176 |
1 | 7 | 7 | 0.8132632 | 1.143985 | 0.3307217 | 0.7539104 | -0.0593528 | -0.3900745 | -1.3076080 | -2.120871 | -2.451593 | -2.061518 |
1 | 8 | 8 | 0.9401779 | 1.613631 | 0.6734529 | 1.0054541 | 0.0652762 | -0.6081767 | -0.9568174 | -1.896995 | -2.570448 | -1.962271 |
1 | 9 | 9 | 0.7503696 | 1.776715 | 1.0263450 | 0.9683937 | 0.2180241 | -0.8083209 | -0.9353992 | -1.685769 | -2.712114 | -1.903793 |
1 | 10 | 10 | 0.7944975 | 1.721003 | 0.9265058 | 0.8337789 | 0.0392814 | -0.8872244 | -0.7732262 | -1.567724 | -2.494229 | -1.607005 |
We can pass that into bayesplot::mcmc_areas() directly. The drop_indices = TRUE argument to unspread_draws() indicates that .chain, .iteration, and .draw should not be included in the output:
m %>%
spread_draws(condition_mean[condition]) %>%
compare_levels(condition_mean, by = condition) %>%
unspread_draws(condition_mean[condition], drop_indices = TRUE) %>%
bayesplot::mcmc_areas()
If you are instead working with tidy draws generated by gather_draws() or gather_variables(), the ungather_draws() function will transform those draws into the draw \(\times\) variable format. It has the same syntax as unspread_draws().
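Here is a minimal sketch of ungather_draws() on a hypothetical long-format data frame shaped like the output of gather_draws(condition_mean[condition]). The values are simulated rather than taken from a model, and the names long and wide are illustrative.

```r
library(dplyr)
library(tidybayes)

# Hypothetical long-format draws: 100 simulated draws for each of three
# conditions, with the variable name in .variable and draws in .value
set.seed(2)
long = data.frame(
  .chain = 1L,
  .iteration = rep(1:100, times = 3),
  .draw = rep(1:100, times = 3),
  condition = rep(c("A", "B", "C"), each = 100),
  .variable = "condition_mean"
)
long$.value = rnorm(300)

# Back to one row per draw and one column per indexed variable
# (condition_mean[A], condition_mean[B], condition_mean[C]);
# drop_indices = TRUE removes .chain, .iteration, and .draw
wide = long %>%
  ungather_draws(condition_mean[condition], drop_indices = TRUE)
```

A data frame in this shape is what functions such as bayesplot::mcmc_areas() expect.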
emmeans (formerly lsmeans)
The emmeans::emmeans() function provides a convenient syntax for generating marginal estimates from a model, including numerous types of contrasts. It also supports some Bayesian modeling packages, like MCMCglmm, rstanarm, and brms. However, it does not provide draws in a tidy format.
The gather_emmeans_draws() function converts output from emmeans into a tidy format, keeping the emmeans reference grid and adding a .value column with long-format draws. (Another approach to using emmeans contrast methods is to use emmeans_comparison() to convert them into comparison functions that can be used with tidybayes::compare_levels(). See the documentation of emmeans_comparison() for more information.)
rstanarm or brms
Both rstanarm and brms behave similarly when used with emmeans::emmeans(). The example below uses rstanarm, but should work similarly for brms.
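The rstanarm model itself is not shown in this section. Here is a minimal sketch of such a model, assuming a Gaussian linear regression of response on condition fit with rstanarm::stan_glm() and its default priors. The simulated ABC stand-in is hypothetical; skip it if ABC is already defined. Fitting runs Stan sampling, so it takes a little while.

```r
# Hypothetical stand-in for the ABC data (skip if ABC already exists)
set.seed(5)
ABC = data.frame(
  condition = factor(rep(c("A", "B", "C", "D", "E"), times = 10)),
  response = rnorm(50, mean = c(0, 1, 2, 1, -1), sd = 0.5)
)

# Bayesian linear regression with rstanarm's default priors;
# seed makes sampling reproducible, refresh = 0 silences progress output
m_rst = rstanarm::stan_glm(response ~ condition, data = ABC, seed = 1, refresh = 0)
```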
Given an rstanarm model of response ~ condition, m_rst, we can use emmeans::emmeans() to get conditional means with uncertainty:
condition | .value | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A | 0.1869666 | -0.1579702 | 0.5263756 | 0.95 | median | qi |
B | 1.0065441 | 0.6589598 | 1.3593600 | 0.95 | median | qi |
C | 1.8685270 | 1.5218745 | 2.2204854 | 0.95 | median | qi |
D | 1.0259611 | 0.6830677 | 1.3888880 | 0.95 | median | qi |
E | -0.9367195 | -1.2835908 | -0.5918623 | 0.95 | median | qi |
Or we can use emmeans::emmeans() with emmeans::contrast() to do all pairwise comparisons:
m_rst %>%
emmeans::emmeans( ~ condition) %>%
emmeans::contrast(method = "pairwise") %>%
gather_emmeans_draws() %>%
median_qi()
contrast | .value | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A - B | -0.8215470 | -1.3158948 | -0.3238283 | 0.95 | median | qi |
A - C | -1.6820415 | -2.1826435 | -1.2104144 | 0.95 | median | qi |
A - D | -0.8373571 | -1.3324811 | -0.3472881 | 0.95 | median | qi |
A - E | 1.1286972 | 0.6148364 | 1.6118181 | 0.95 | median | qi |
B - C | -0.8631374 | -1.3552500 | -0.3387682 | 0.95 | median | qi |
B - D | -0.0157834 | -0.5212377 | 0.4987398 | 0.95 | median | qi |
B - E | 1.9529800 | 1.4481070 | 2.4472639 | 0.95 | median | qi |
C - D | 0.8428591 | 0.3416806 | 1.3415446 | 0.95 | median | qi |
C - E | 2.8072029 | 2.3292265 | 3.3096635 | 0.95 | median | qi |
D - E | 1.9664081 | 1.4597655 | 2.4630503 | 0.95 | median | qi |
See the documentation for emmeans::pairwise.emmc() for a list of the numerous contrast types supported by emmeans::emmeans().
As before, we can plot the results instead of using a table:
m_rst %>%
emmeans::emmeans( ~ condition) %>%
emmeans::contrast(method = "pairwise") %>%
gather_emmeans_draws() %>%
ggplot(aes(x = .value, y = contrast)) +
stat_halfeye()
gather_emmeans_draws() also supports emm_list objects, which contain estimates from different reference grids (see emmeans::ref_grid() for more information on reference grids). An additional column, named .grid by default, indicates the reference grid for each row in the output:
condition | contrast | .grid | .value | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|---|---|
A | NA | emmeans | 0.1869666 | -0.1579702 | 0.5263756 | 0.95 | median | qi |
B | NA | emmeans | 1.0065441 | 0.6589598 | 1.3593600 | 0.95 | median | qi |
C | NA | emmeans | 1.8685270 | 1.5218745 | 2.2204854 | 0.95 | median | qi |
D | NA | emmeans | 1.0259611 | 0.6830677 | 1.3888880 | 0.95 | median | qi |
E | NA | emmeans | -0.9367195 | -1.2835908 | -0.5918623 | 0.95 | median | qi |
NA | A - B | contrasts | -0.8215470 | -1.3158948 | -0.3238283 | 0.95 | median | qi |
NA | A - C | contrasts | -1.6820415 | -2.1826435 | -1.2104144 | 0.95 | median | qi |
NA | A - D | contrasts | -0.8373571 | -1.3324811 | -0.3472881 | 0.95 | median | qi |
NA | A - E | contrasts | 1.1286972 | 0.6148364 | 1.6118181 | 0.95 | median | qi |
NA | B - C | contrasts | -0.8631374 | -1.3552500 | -0.3387682 | 0.95 | median | qi |
NA | B - D | contrasts | -0.0157834 | -0.5212377 | 0.4987398 | 0.95 | median | qi |
NA | B - E | contrasts | 1.9529800 | 1.4481070 | 2.4472639 | 0.95 | median | qi |
NA | C - D | contrasts | 0.8428591 | 0.3416806 | 1.3415446 | 0.95 | median | qi |
NA | C - E | contrasts | 2.8072029 | 2.3292265 | 3.3096635 | 0.95 | median | qi |
NA | D - E | contrasts | 1.9664081 | 1.4597655 | 2.4630503 | 0.95 | median | qi |
MCMCglmm
Let’s do the same example again, this time using MCMCglmm::MCMCglmm() instead of rstanarm. The only difference is that, to use MCMCglmm() with emmeans(), you must also pass the original data used to fit the model to the emmeans() call (see vignette("models", package = "emmeans") for more information).
First, we’ll fit the model:
# MCMCglmm does not support tibbles directly,
# so we convert ABC to a data.frame on the way in
m_glmm = MCMCglmm::MCMCglmm(response ~ condition, data = as.data.frame(ABC))
Now we can use emmeans() and gather_emmeans_draws() exactly as we did with rstanarm, but we need to include a data argument in the emmeans() call: