Components of a model

Based on section 4.2 of the book, “A language for describing models”.

Variables

Data

Observable things.

Parameters

Unobserved things that we have to infer from the data.

“Joint generative model”

We define each variable either in terms of a probability distribution, or in terms of its relationship to the other variables.

Taken together, this system of definitions is what McElreath calls a “joint generative model”.
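As a minimal sketch of what “generative” means here, the base-R function below produces one simulated world. It first draws the parameters from their priors and then draws data given those parameters. It assumes the priors used later in this section (\(\mu \sim N(20, 6)\), \(\sigma \sim \text{Uniform}(0, 20)\)); the function name `simulate_joint` is hypothetical.

```r
# Hypothetical helper: one draw from the joint generative model
#   y_i ~ Normal(mu, sigma), mu ~ Normal(20, 6), sigma ~ Uniform(0, 20)
simulate_joint <- function(n_obs) {
  mu <- rnorm(1, mean = 20, sd = 6)         # parameter: drawn from its prior
  sigma <- runif(1, min = 0, max = 20)      # parameter: drawn from its prior
  y <- rnorm(n_obs, mean = mu, sd = sigma)  # data: drawn given the parameters
  list(mu = mu, sigma = sigma, y = y)
}

set.seed(1)
one_world <- simulate_joint(n_obs = 5)
str(one_world)
```

Running it repeatedly gives a different world each time. Fitting the model reverses this process: given the observed `y`, which values of `mu` and `sigma` are plausible?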

Penguins example

Let’s model bill_depth_mm for Chinstrap penguins. First, we won’t use any predictor variables; we’re just looking for the mean.

library(palmerpenguins)
library(dplyr)

library(ggplot2)
library(ggdist)
theme_set(theme_minimal())

penguins <- penguins |>
  filter(!is.na(sex)) |>
  filter(species == "Chinstrap") 

ggplot(penguins, aes(bill_depth_mm)) +
  geom_dotsinterval()

Defining the model

\(\text{billdepth}_i \sim N(\mu, \sigma)\)

This just says that bill depth is normally distributed with some mean and some standard deviation. We can observe bill depth (data), but we don’t know the mean or standard deviation (parameters).

For the parameters, we need priors.

\(\mu \sim N(20, 6)\)

This is to say that \(\mu\) is some number drawn from a normal distribution centered on 20 with a standard deviation of 6:

data.frame(mu = rnorm(1000, mean = 20, sd = 6)) |>
  ggplot(aes(x = mu)) +
  geom_dotsinterval() +
  ggtitle("Simulations from prior for mu")
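The simulation can also be summarized exactly. This short base-R snippet computes the central 95% interval implied by the prior \(\mu \sim N(20, 6)\), without any random draws.

```r
# Central 95% interval of the prior mu ~ Normal(20, 6)
prior_mu_interval <- qnorm(c(0.025, 0.975), mean = 20, sd = 6)
round(prior_mu_interval, 2)  # about 8.24 and 31.76
```

So before seeing any data, the model treats average bill depths from roughly 8 to 32 mm as plausible. That is wide, but it still rules out absurd values such as 100 mm.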

\(\sigma \sim \text{Uniform}(0, 20)\)

And here we’re saying that \(\sigma\) could be anywhere between 0 and 20, with every value in that range equally likely. Standard deviations can’t be negative, and otherwise this is a very broad range.

data.frame(sigma = runif(1000, 0, 20)) |>
  ggplot(aes(x = sigma)) +
  geom_dotsinterval() +
  ggtitle("Simulations from prior for sigma")
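Because the prior is uniform, its key properties can be written down directly. The sketch below uses only the standard formulas for a uniform distribution and `punif()`. It also reports how much prior mass sits below 2 mm, an arbitrary cut-off chosen for illustration.

```r
# Properties of the prior sigma ~ Uniform(0, 20)
prior_sigma_mean <- (0 + 20) / 2                   # mean of a uniform distribution
prior_sigma_sd <- (20 - 0) / sqrt(12)              # standard deviation of a uniform distribution
prob_sigma_below_2 <- punif(2, min = 0, max = 20)  # P(sigma < 2)
c(mean = prior_sigma_mean, sd = prior_sigma_sd, p_below_2 = prob_sigma_below_2)
```

Only 10% of the prior mass lies below 2 mm. If the posterior for \(\sigma\) ends up in that region, the data, not the prior, are driving the estimate.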

Exploring the generative model

# Draw parameter values from their priors...
sample_mu <- rnorm(1000, 20, 6)
sample_sigma <- runif(1000, 0, 20)

# ...then one simulated bill depth per parameter draw
data.frame(simulated_bill_depths =
             rnorm(1000,
                   sample_mu,
                   sample_sigma)) |>
  ggplot(aes(simulated_bill_depths)) +
  geom_dotsinterval() +
  ggtitle("Simulated bill depths from priors")

This is our simulation of expected bill depths before explicitly taking into account the data (although we did look at the dot plot of the data before specifying the priors).
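One quick way to interrogate this prior predictive distribution is to ask how often it produces impossible values. This base-R sketch repeats the simulation with the same priors and counts negative bill depths. The seed and the number of simulations are arbitrary choices.

```r
set.seed(42)
n_sims <- 10000
sample_mu <- rnorm(n_sims, mean = 20, sd = 6)       # mu ~ Normal(20, 6)
sample_sigma <- runif(n_sims, min = 0, max = 20)    # sigma ~ Uniform(0, 20)
simulated_depths <- rnorm(n_sims, mean = sample_mu, sd = sample_sigma)

# Share of simulated bill depths that are physically impossible (negative)
share_negative <- mean(simulated_depths < 0)
share_negative
```

A non-zero share suggests the priors, especially the one on \(\sigma\), are looser than they need to be. Whether that matters depends on how much data there is to override them.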

Taking into account the data

See here for translations of rethinking code to brms.

library(brms)

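bill_depth_brm <-
  brm(
    formula = bill_depth_mm ~ 1,  # intercept only, so the intercept is mu
    family = gaussian,
    data = penguins,
    prior = c(
      prior(normal(20, 6), class = Intercept),       # mu ~ N(20, 6)
      prior(uniform(0, 20), class = sigma, ub = 20)  # sigma ~ Uniform(0, 20); ub matches the prior's upper bound
    ),
    iter = 1000
  )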
plot(bill_depth_brm)

summary(bill_depth_brm)
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: bill_depth_mm ~ 1 
   Data: penguins (Number of observations: 68) 
  Draws: 4 chains, each with iter = 1000; warmup = 500; thin = 1;
         total post-warmup draws = 2000

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept    18.42      0.14    18.13    18.70 1.00     1390     1244

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     1.16      0.10     0.98     1.36 1.00     1501     1307

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
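brms fits this model with Hamiltonian Monte Carlo (the NUTS sampler named in the output). For a two-parameter model, the same posterior can also be approximated on a grid, an approach McElreath also uses in Chapter 4. The sketch below is that swapped-in technique, offered for intuition rather than as a description of what brms does. So that it runs on its own, it uses simulated stand-in data (68 hypothetical observations with a mean of about 18.4 and a standard deviation of about 1.2) instead of the penguins data.

```r
# Grid approximation of the posterior for y_i ~ Normal(mu, sigma),
# with mu ~ Normal(20, 6) and sigma ~ Uniform(0, 20).
set.seed(2024)
y <- rnorm(68, mean = 18.4, sd = 1.2)  # hypothetical stand-in data

grid <- expand.grid(
  mu = seq(17, 20, by = 0.01),
  sigma = seq(0.5, 2.5, by = 0.01)
)

# Log-likelihood of all observations at each grid point
grid$log_lik <- vapply(
  seq_len(nrow(grid)),
  function(i) sum(dnorm(y, grid$mu[i], grid$sigma[i], log = TRUE)),
  numeric(1)
)

# Add the log-priors, then normalize on the probability scale
grid$log_post <- grid$log_lik +
  dnorm(grid$mu, 20, 6, log = TRUE) +
  dunif(grid$sigma, 0, 20, log = TRUE)
grid$post <- exp(grid$log_post - max(grid$log_post))
grid$post <- grid$post / sum(grid$post)

post_mean_mu <- sum(grid$mu * grid$post)
post_mean_sigma <- sum(grid$sigma * grid$post)
c(sample_mean = mean(y), post_mean_mu = post_mean_mu, post_mean_sigma = post_mean_sigma)
```

With 68 observations the \(N(20, 6)\) prior is very weak, so the posterior mean of \(\mu\) lands almost exactly on the sample mean. Grids scale badly as parameters are added, which is why samplers like NUTS are used in practice.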

Exploring the posterior

posterior <- as_draws_df(bill_depth_brm)

head(posterior)
# A draws_df: 6 iterations, 1 chains, and 5 variables
  b_Intercept sigma Intercept lprior lp__
1          19   1.2        19   -5.7 -111
2          18   1.1        18   -5.7 -111
3          18   1.1        18   -5.7 -111
4          18   1.1        18   -5.7 -111
5          18   1.1        18   -5.7 -111
6          18   1.0        18   -5.7 -111
# ... hidden reserved variables {'.chain', '.iteration', '.draw'}
ggplot(posterior, aes(Intercept)) +
  stat_dotsinterval()

ggplot(posterior, aes(sigma)) +
  stat_dotsinterval()

posterior_summary(bill_depth_brm)
               Estimate   Est.Error         Q2.5       Q97.5
b_Intercept   18.421378 0.140472861   18.1294656   18.696755
sigma          1.155849 0.098822889    0.9816352    1.359573
Intercept     18.421378 0.140472861   18.1294656   18.696755
lprior        -5.741316 0.006166867   -5.7550261   -5.730020
lp__        -111.280959 1.026213063 -114.1671862 -110.321475
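By default, `posterior_summary()` reduces each parameter's draws to a posterior mean, a standard deviation, and the 2.5% and 97.5% quantiles. The sketch below computes the same four numbers by hand. So that it runs on its own, it uses simulated stand-in draws; with the real fit you would apply the same functions to `posterior$Intercept`.

```r
set.seed(7)
# Hypothetical stand-in for 2000 posterior draws of the intercept
draws <- rnorm(2000, mean = 18.42, sd = 0.14)

manual_summary <- c(
  Estimate = mean(draws),                 # posterior mean
  Est.Error = sd(draws),                  # posterior standard deviation
  Q2.5 = unname(quantile(draws, 0.025)),  # lower end of the 95% interval
  Q97.5 = unname(quantile(draws, 0.975))  # upper end of the 95% interval
)
round(manual_summary, 3)
```

In this intercept-only model, `b_Intercept` and `Intercept` are identical because there are no predictors to center.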

Penguins with a predictor

OK, now let’s model bill_depth as a function of bill_length.

ggplot(penguins, aes(bill_length_mm, 
                     bill_depth_mm)) +
  geom_point()

Specifying the model

\(\text{billdepth}_i \sim N(\mu_i, \sigma)\)

Here, again, we say that bill depth is normally distributed with some mean and standard deviation, but now the mean \(\mu_i\) can differ from penguin to penguin.

\(\mu_i = \alpha + \beta(\text{billlength}_i - \text{mean}(\text{billlength}))\)

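This addition makes it a linear model. Instead of estimating a single \(\mu\) directly, we say that the mean bill depth \(\mu_i\) varies with bill length as a linear function with an intercept \(\alpha\) and a slope \(\beta\). Because bill length is centered on its mean, \(\alpha\) is the expected bill depth for a penguin of average bill length.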
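A small numeric check of that interpretation follows. The helper `expected_depth` and the values \(\alpha = 18\), \(\beta = 0.2\), and an average bill length of 49 mm are made up for illustration; they are not estimates.

```r
# Hypothetical helper: expected bill depth under the centered linear model
expected_depth <- function(bill_length, alpha, beta, mean_length) {
  alpha + beta * (bill_length - mean_length)
}

# Illustrative (made-up) values
at_mean <- expected_depth(49, alpha = 18, beta = 0.2, mean_length = 49)  # returns alpha
longer  <- expected_depth(54, alpha = 18, beta = 0.2, mean_length = 49)  # 18 + 0.2 * 5
c(at_mean = at_mean, five_mm_longer = longer)
```

At the average bill length the slope term vanishes, leaving \(\alpha\). Each extra millimetre of bill length adds \(\beta\) millimetres of expected depth.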

\(\alpha\) and \(\beta\) are now additional parameters that we will estimate and therefore need to set priors for.

\(\alpha \sim N(20, 6)\)

\(\beta \sim N(0, 5)\)

\(\sigma \sim \text{Uniform}(0, 20)\)

Exploring the generative model

prior_draw <- data.frame(
  bill_length = seq(min(penguins$bill_length_mm), max(penguins$bill_length_mm), length.out = 100),
  intercept = rnorm(1, 20, 6),  # one draw of alpha
  beta = rnorm(1, 0, 5),        # one draw of beta
  sigma = runif(1, 0, 20)       # one draw of sigma
) |>
  # center on the mean of the observed bill lengths (as in the model), not of this grid
  mutate(mu = intercept + beta * (bill_length - mean(penguins$bill_length_mm)),
         # vectorized: one simulated depth per row, so rowwise() isn't needed
         sim_depth = rnorm(n(), mu, sigma))

ggplot(prior_draw, aes(bill_length, mu)) +
  geom_line() +
  geom_ribbon(aes(ymin = mu - sigma,
                  ymax = mu + sigma),
              alpha = .3) +
  geom_point(aes(y = sim_depth)) +
  ggtitle("Simulated bill depths",
          subtitle = "Based on a single draw from the priors")
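A single draw shows one possible world; repeating the draw many times shows what the priors imply collectively. This base-R sketch assumes, for illustration only, that bill lengths fall within about 9 mm either side of their mean. It then counts how often a prior line predicts a negative bill depth at one end of that range.

```r
set.seed(11)
n_lines <- 10000
alpha <- rnorm(n_lines, 20, 6)  # prior for the intercept at average bill length
beta <- rnorm(n_lines, 0, 5)    # prior for the slope
half_range <- 9                 # hypothetical: bill lengths within about +/- 9 mm of the mean

# Expected bill depth at the two ends of the range, for every prior draw
mu_low <- alpha - beta * half_range
mu_high <- alpha + beta * half_range

share_impossible <- mean(pmin(mu_low, mu_high) < 0)
share_impossible
```

Under these assumptions, well over half of the prior lines imply negative bill depths somewhere in the observed range. This suggests \(\beta \sim N(0, 5)\) is very permissive for a slope measured in millimetres of depth per millimetre of length, and a tighter slope prior would be a reasonable refinement.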

Fitting the model (with data)

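depth_length_brm <- brm(
  family = gaussian,
  data = penguins,
  formula = bill_depth_mm ~ bill_length_mm,
  prior = c(
    # brms centers predictors internally by default, so this Intercept prior
    # applies to alpha: the expected bill depth at the mean bill length
    prior(normal(20, 6), class = Intercept),
    prior(normal(0, 5), class = b),
    prior(uniform(0, 20), class = sigma, ub = 20)
  ),
  iter = 1000
)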
plot(depth_length_brm)

summary(depth_length_brm)
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: bill_depth_mm ~ bill_length_mm 
   Data: penguins (Number of observations: 68) 
  Draws: 4 chains, each with iter = 1000; warmup = 500; thin = 1;
         total post-warmup draws = 2000

Regression Coefficients:
               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept          7.57      1.60     4.29    10.62 1.00     1828     1302
bill_length_mm     0.22      0.03     0.16     0.29 1.00     1803     1301

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     0.88      0.08     0.75     1.05 1.00     2180     1472

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Exploring the posterior

lm_posterior <- as_draws_df(depth_length_brm)

head(lm_posterior)
# A draws_df: 6 iterations, 1 chains, and 6 variables
  b_Intercept b_bill_length_mm sigma Intercept lprior lp__
1         7.4             0.23  0.96        19   -8.3  -97
2         7.0             0.24  0.94        19   -8.3  -96
3         7.8             0.22  0.89        19   -8.3  -96
4         7.6             0.23  0.79        19   -8.3  -96
5         7.4             0.22  0.93        18   -8.3  -96
6         8.9             0.20  0.92        18   -8.3  -95
# ... hidden reserved variables {'.chain', '.iteration', '.draw'}
ggplot(lm_posterior, aes(b_Intercept)) +
  stat_dotsinterval()

ggplot(lm_posterior, aes(b_bill_length_mm)) +
  stat_dotsinterval()

ggplot(lm_posterior, aes(sigma)) +
  stat_dotsinterval()

posterior_summary(depth_length_brm)
                    Estimate   Est.Error        Q2.5       Q97.5
b_Intercept        7.5678375 1.597685969   4.2885750  10.6186846
b_bill_length_mm   0.2222723 0.032584203   0.1603925   0.2888031
sigma              0.8829648 0.077046605   0.7505312   1.0523065
Intercept         18.4222455 0.108770766  18.2085474  18.6326642
lprior            -8.2705540 0.004796696  -8.2807804  -8.2617956
lp__             -95.6697446 1.293997087 -99.0571858 -94.2660701
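In this output, `b_Intercept` (about 7.57, which `summary()` also reports as Intercept) and `Intercept` (about 18.42) are two versions of the same intercept. `Intercept` is the centered intercept, the \(\alpha\) of the model specification. `b_Intercept` is the expected bill depth at a bill length of 0 mm. They are linked by Intercept = b_Intercept + b_bill_length_mm × mean bill length, and the posterior means above imply a mean bill length of roughly (18.42 − 7.57) / 0.222 ≈ 48.8 mm. The same identity holds for ordinary least squares, which makes it easy to check in base R. The sketch uses hypothetical stand-in data loosely shaped like this fit and `lm()`, not brms.

```r
set.seed(3)
# Hypothetical stand-in data for bill length (x) and bill depth (y)
x <- rnorm(68, mean = 49, sd = 3)
y <- 7.5 + 0.22 * x + rnorm(68, sd = 0.9)

fit_raw <- lm(y ~ x)                    # intercept at x = 0
fit_centered <- lm(y ~ I(x - mean(x)))  # intercept at x = mean(x)

raw_intercept <- unname(coef(fit_raw)[1])
slope <- unname(coef(fit_raw)[2])
centered_intercept <- unname(coef(fit_centered)[1])

# The centered intercept equals the raw intercept plus slope * mean(x)
c(centered = centered_intercept, reconstructed = raw_intercept + slope * mean(x))
```

Centering changes only the intercept; the slope is identical in both fits. This is why the centered parameterization in the model specification and the uncentered brms formula describe the same model.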