Help with reparameterisation (case study with the Dirichlet-Multinomial)

Hi there @nick, this continues our conversation on joint species distribution modelling, but I’ll give some background here for others.

Model

Say we want to model the counts Y_{ij} of species j at site i as a function of the site environment x_i and species traits T_j. Let’s consider one environmental covariate and one trait for now.

Rather than the negative binomial, I want to explore the Dirichlet-Multinomial (DMN), and greta remains one of the few tools that allow this flexibility :slight_smile: Our generative model is thus:

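\mathbf{Y}_{i} \sim \text{DMN}(N_i,~\boldsymbol{\alpha}_{i}), \quad \mathbf{Y}_{i} = (Y_{i1}, \dots, Y_{iJ}), \quad \boldsymbol{\alpha}_{i} = (\alpha_{i1}, \dots, \alpha_{iJ})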

Below I’ll fix N_i = N = 100 for all sites, although this can of course vary by site. I’ll also parameterise the concentration as \alpha_{ij} = \mu_{ij} \theta, where \mu_{ij} are the relative abundances (a simplex, summing to one over species) and \theta is the precision (which could also vary by site, but here is fixed to be the same for all sites).
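To make the role of \theta concrete: with \alpha_{ij} = \mu_{ij} \theta, the DMN has the same mean as a multinomial, E[Y_{ij}] = N \mu_{ij}, but its variance is inflated by the factor (N + \theta)/(1 + \theta). A small base-R sketch (not greta) that evaluates this factor at the quantiles of the log_theta prior used in the greta model further down, Normal(log(100), sqrt(2.5)); the helper dmn_var_inflation is a hypothetical name of mine.

```r
# Variance inflation of the DMN relative to a multinomial with the same mean,
# assuming alpha_ij = mu_ij * theta (so each row of alpha sums to theta):
#   Var(Y_ij) = N * mu_ij * (1 - mu_ij) * (N + theta) / (1 + theta)
dmn_var_inflation <- function(N, theta) (N + theta) / (1 + theta)

N <- 100

# 2.5%, 50% and 97.5% quantiles of theta implied by the log_theta prior,
# log_theta ~ Normal(log(100), sqrt(2.5))
theta_q <- qlnorm(c(0.025, 0.5, 0.975), meanlog = log(100), sdlog = sqrt(2.5))
round(theta_q, 1)

# Inflation factor at those quantiles (a factor near 1 means close to multinomial)
round(dmn_var_inflation(N, theta_q), 2)
```

If I have the formula right, this prior lets \theta range from roughly 4.5 to about 2200 (median 100), so the variance inflation runs from about 19 down to about 1.04, with about 2 at the median. That is a wide range of dispersion, which may be worth keeping in mind when judging whether the log \theta prior is sensible.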

Thus our linear predictor is:

\text{multilogit}(\mu_{ij}) = a_i + \log c_j -\frac{1}{2}\left( \frac{x_i - u_j}{t_j} \right)^2 \\ \log c_j \sim \text{Normal}(\delta T_j,~ \sigma_{\log c}) \\ u_j \sim \text{Normal}(\kappa T_j,~ \sigma_{u}) \\ \log t_j \sim \text{Normal}(\gamma T_j,~ \sigma_{\log t}) \,,

where a_i are site random intercepts, and the remaining terms in the linear predictor come from the Gaussian function, because I’m modelling abundance as a nonlinear function of the environment. The traits then moderate each of the Gaussian parameters. Earlier, you showed that reparameterising with u_over_t = u_j / t_j helps with sampling, but here I want to keep u_j and t_j separate so that I can examine the (potentially different) trait effects \kappa and \gamma. Maybe there is still a clever way to retrieve them under the u_j / t_j parameterisation…?
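For reference, the u_over_t idea rests on a simple identity: with r_j = u_j / t_j, the Gaussian term -\frac{1}{2}\left(\frac{x_i - u_j}{t_j}\right)^2 equals -\frac{1}{2}\left(\frac{x_i}{t_j} - r_j\right)^2, and it also equals the expanded form used in the simulation code. A base-R check with arbitrary (hypothetical) values. Note that this only recovers u_j = r_j t_j draw by draw; it does not by itself give \kappa, which would still need the hierarchical prior u_j \sim \text{Normal}(\kappa T_j, \sigma_u) to be expressed through the sampled r_j and t_j, and I haven’t worked that out here.

```r
# Arbitrary (hypothetical) values for an environmental gradient and one species
x_demo <- seq(-3, 3, length.out = 7)
u_demo <- 2.5
t_demo <- 0.8

# 1. Original form: -0.5 * ((x - u) / t)^2
g_orig <- -0.5 * ((x_demo - u_demo) / t_demo)^2

# 2. Expanded form used in the simulation code:
#    -(0.5 * (u/t)^2 - (u/t^2) * x + x^2 / (2 * t^2))
g_expanded <- -(0.5 * (u_demo / t_demo)^2 -
                  (u_demo / t_demo^2) * x_demo +
                  x_demo^2 / (2 * t_demo^2))

# 3. u_over_t form: with r = u / t, the term is -0.5 * (x / t - r)^2
r_demo  <- u_demo / t_demo
g_ratio <- -0.5 * (x_demo / t_demo - r_demo)^2

all.equal(g_orig, g_expanded)   # TRUE
all.equal(g_orig, g_ratio)      # TRUE

# Given posterior draws of r and t, u is recovered draw by draw as r * t
u_recovered <- r_demo * t_demo
```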

Anyhow, below is the code to generate data and fit a greta model. It demonstrates a few key issues I’m facing, even after reparameterising several priors and scaling the covariates:

  • HMC sampling gets increasingly difficult when the true u_j are extreme (i.e., species with very extreme optima)
  • This is probably because the posteriors of u and t become more strongly correlated. The u_over_t trick helps, but (1) it only seems to work when u lies roughly within [-2, 2] and still struggles beyond these bounds, and (2) I don’t know how to retrieve \kappa and \gamma from this reparameterisation
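One way to see why extreme optima could make u, t (and c) strongly correlated: on the linear-predictor scale each species’ term is a quadratic in x, \log c_j - \frac{u_j^2}{2t_j^2} + \frac{u_j}{t_j^2} x - \frac{x^2}{2t_j^2}. If the optimum lies well outside the sampled range of x, the data mostly pin down the intercept and slope at x = 0, while the curvature -1/(2t_j^2) is close to zero and weakly informed. A base-R sketch with hypothetical values, holding intercept and slope fixed and changing only the curvature:

```r
# Hypothetical species whose optimum lies beyond the sampled gradient.
# On the linear-predictor scale its term is b0 + b1 * x + b2 * x^2, with
#   b1 = u / t^2,  b2 = -1 / (2 * t^2),  b0 = log c - u^2 / (2 * t^2)
b0 <- 0                        # intercept at x = 0, held fixed
b1 <- 1                        # slope at x = 0, held fixed
b2 <- -c(0.10, 0.05, 0.025)    # small curvatures, weakly informed by data

# Map back to the niche parameters
t_implied    <- 1 / sqrt(-2 * b2)
u_implied    <- -b1 / (2 * b2)               # equals b1 * t^2
logc_implied <- b0 + u_implied^2 / (2 * t_implied^2)

round(cbind(u = u_implied, t = t_implied, log_c = logc_implied), 2)

# Largest difference between the three curves over the sampled range of x
x_obs <- seq(-2, 2, length.out = 41)
curve_diff <- max(abs(outer(x_obs^2, b2 - b2[1])))
curve_diff
```

Under these assumptions, (u, t, log c) of about (5, 2.24, 2.5), (10, 3.16, 5) and (20, 4.47, 10) give linear predictors that differ by at most 0.3 across x in [-2, 2], so the posterior can stretch along a curved ridge where u, t and log c increase together. That would be consistent with the correlations in the pairs plot below, although this sketch ignores the multilogit and site terms.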

After the code I’ll speculate on other possible causes of these issues.

Code

library(greta)


# Simulate data -----------------------------------------------------------
set.seed(1)

n_site <- 100  # number of sites
n_sp   <- 11   # number of species 
n_env  <- 1    # number of environmental covariates
n_trait <- 1   # number of traits

# Generate environmental gradient
x <- replicate(n_env, rnorm(n_site))   

# Generate traits
# unlike the environment, traits are placed at regular intervals:
# traits later determine species optima, tolerances and peaks, so we
# control them to examine whether species parameters become harder to
# estimate as traits become more extreme.
# Here I use -2 to 2; a wider trait range generally makes the posterior
# harder to sample because of correlated posteriors (see below)
z <- cbind(seq(-2, 2, length.out = n_sp))

# Species optima
# fixing trait effect on optima (kappa) to be one
kappa_true <- 1
mu_u_true  <- kappa_true %*% t(z)
sd_u_true  <- 0.5
u_true     <- matrix(rnorm(n_sp, mu_u_true, sd_u_true))

# Species tolerances
# fixing trait effect on tolerances (gamma) to be zero
gamma_true <- 0
mu_t_true  <- gamma_true %*% t(z)
sd_t_true  <- 0.5
t_true     <- exp(matrix(rnorm(n_sp, mu_t_true, sd_t_true)))

# Species maxima / peaks
# fixing trait effect on peaks (delta) to be zero
# but note that the residual SD is much higher to reflect that species peaks
# tend to be more variable than other parameters, as traits tend to explain less
# of species' overall abundances
delta_true <- 0
mu_c_true  <- delta_true %*% t(z)
sd_c_true  <- 2
c_true     <- matrix(rnorm(n_sp, mu_c_true, sd_c_true))

# Random site variations
# not sure if this causes non-identifiability when there is also
# theta (precision) in a Dirichlet-Multinomial,
# which can also be a site (row)-specific parameter
sd_Pi_true <- 0.1
Pi_true <- rnorm(n_site, 0, sd_Pi_true)

# The Gaussian function
# doing it in a more general form
# by expansion: eta = c - 0.5*(u/t)^2 + (u/t^2)*x - 0.5/(t^2)*(x^2)
# (this came from another main project where there are multiple species and 
# multiple environments, but maybe there is a better/cleaner way?)
f0_true  <- matrix(1, n_site, n_env) %*% t(0.5 * (u_true / t_true)^2)
fx_true  <- x %*% t(u_true / (t_true^2))
fx2_true <- x^2 %*% t(1/(2*(t_true^2)))

# linear predictor
eta_true <- - (f0_true - fx_true + fx2_true)
eta_true <- sweep(eta_true, 2, c_true, "+")  # add random species intercepts
eta_true <- sweep(eta_true, 1, Pi_true, "+") # add random site intercepts

# Generate outcomes
# Some of the literature generates the relative abundances (simplex) with
# rdirichlet; I'm not sure which is better, but here I use the inverse multilogit
mu_true        <- VGAM::multilogitlink(eta_true, inverse = TRUE)
# we parameterise the concentration (alpha) of the Dirichlet-Multinomial as
# mu*theta, but compute it as log alpha = log mu + log theta, hoping for more
# stable chains
log_theta_true <- log(100)
# alpha <- mu * exp(log_theta_true)
log_alpha_true <- log(mu_true) + log_theta_true
alpha_true <- exp(log_alpha_true)

# Total abundance is fixed to be the same for all sites, though it can vary
# (as can theta above)
n_trials <- 100 

Y_obs <- extraDistr::rdirmnom(n_site, n_trials, alpha_true)

# matplot(x, Y_obs)   # quick look




# Greta model -------------------------------------------------------------

# priors
# random site intercepts
# SDs below use the uniform-tangent reparameterisation of the half-Cauchy,
# and random effects use the non-centred parameterisation
sd_Pi_unif <- uniform(0, pi/2)  # scalar: one SD shared by all site intercepts
sd_Pi      <- 2.5 * tan(sd_Pi_unif)
z_Pi       <- normal(0, 1, dim = n_site)
Pi         <- z_Pi * sd_Pi

# Gaussian niche function
# species tolerances
z_gamma   <- normal(0, 1, dim = c(n_env, n_trait))
sd_gamma  <- 1
gamma     <- sd_gamma * z_gamma
sd_t_unif <- uniform(0, pi/2, dim = n_env)
sd_t      <- 2.5 * tan(sd_t_unif)
mu_t      <- gamma %*% t(z)
z_t       <- normal(0, 1, dim = c(n_env, n_sp))
t         <- t(exp(mu_t + sweep(z_t, 1, sd_t, "*")))

# species optima
z_kappa   <- normal(0, 1, dim = c(n_env, n_trait))
sd_kappa  <- 1
kappa     <- sd_kappa * z_kappa
sd_u_unif <- uniform(0, pi/2, dim = n_env)
sd_u      <- 2.5 * tan(sd_u_unif)
mu_u      <- kappa %*% t(z)
z_u       <- normal(0, 1, dim = c(n_env, n_sp))
u         <- t(mu_u + sweep(z_u, 1, sd_u, "*"))

# species peaks
z_delta   <- normal(0, 1, dim = c(1, n_trait))
sd_delta  <- 1
delta     <- sd_delta * z_delta
sd_c_unif <- uniform(0, pi/2)
sd_c      <- 2.5 * tan(sd_c_unif)
mu_c      <- delta %*% t(z)
z_c       <- normal(0, 1, dim = c(1, n_sp))
c         <- t(mu_c + sweep(z_c, 1, sd_c, "*"))

# The Gaussian function
f0 <- ones(n_site, n_env) %*% t(0.5 * (u / t)^2)
fx <- x %*% t(u / (t^2))
fx2 <- x^2 %*% t(1/(2*(t^2)))

# linear predictor
eta <- - (f0 - fx + fx2)
eta <- sweep(eta, 2, c, "+")  # add random species intercepts
eta <- sweep(eta, 1, Pi, "+") # add random site intercepts

# likelihood: dirichlet multinomial
# mu <- dirichlet(exp(eta))
mu <- imultilogit(eta)
# prior for log theta: N(log(100), sqrt(2.5)), i.e. centred on the simulated value
# not sure if there is anything better?
log_theta <- log(100) + sqrt(2.5) * normal(0, 1)
log_alpha <- log(mu) + log_theta
alpha <- exp(log_alpha)

distribution(Y_obs) <- dirichlet_multinomial(n_trials, alpha)

# Model
m <- model(
    log_theta,
    # Pi,
    c, u, t, 
    kappa, gamma, delta,
    sd_c, sd_u, sd_t, sd_Pi)

# Sampling
# I needed to increase the number of leapfrog steps and the warmup for the
# chains to mix better; even so, sampling is not very efficient, and L seems to
# matter more than the warmup length. This is already the optimistic case: the
# real data have more sites and more species (and more noise), and there neither
# more leapfrog steps nor a longer warmup helped
draws <- mcmc(m,
              sampler = hmc(20, 25),
              warmup = 3000)



library(bayesplot)
mcmc_intervals(draws, vars(-starts_with("log_theta")))  # drop log_theta (different scale)
mcmc_trace(draws, vars(starts_with("u")))
mcmc_trace(draws, vars(starts_with("t[")))
mcmc_trace(draws, vars(starts_with("c")))

# the trace plots look okay, and the parameters are reasonably close to their
# true values, but the posteriors are still correlated, which makes sampling
# inefficient

mcmc_pairs(draws, c("u[1,1]", "c[1,1]", "t[1,1]"))
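The SD priors in the greta code use the inverse-CDF construction of the half-Cauchy: if U ~ Uniform(0, π/2), then 2.5 tan(U) is half-Cauchy with scale 2.5. A quick base-R check (names here are my own), comparing the transformed uniform quantiles with half-Cauchy quantiles, which for scale s are the Cauchy quantiles at (1 + p)/2:

```r
scale_hc <- 2.5
p <- c(0.1, 0.25, 0.5, 0.75, 0.9)

# The p-quantile of U ~ Uniform(0, pi/2) is p * pi / 2, and tan() is
# increasing on (0, pi/2), so the quantiles map through directly
q_tan <- scale_hc * tan(p * pi / 2)

# Half-Cauchy(0, s) quantile = Cauchy(0, s) quantile at (1 + p) / 2
q_half_cauchy <- qcauchy((1 + p) / 2, location = 0, scale = scale_hc)

all.equal(q_tan, q_half_cauchy)   # TRUE

# Monte Carlo version: the median of 2.5 * tan(U) should be close to 2.5
set.seed(1)
mc_median <- median(scale_hc * tan(runif(1e5, 0, pi / 2)))
mc_median
```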

Although this model appears to sample well on the surface, it isn’t very efficient, and I suspect the correlated posteriors are the main problem in the much larger real dataset I’m analysing.

Possible causes

  1. The model could be misspecified, in particular:
    a. Not sure if the site-specific random intercept a_i is identifiable, given how the DMN works: the precision \theta_i is also a row-level parameter
    b. Maybe I’m missing some (population-level) intercept terms in the submodels \log c_j, u_j and t_j
    c. I hope I’m calculating the concentration \alpha_{ij} correctly…
  2. Some priors may be nonsensical, in particular that of \log \theta
  3. Should \mu_{ij} come from the inverse multilogit link, or be generated from another Dirichlet? I’m still a bit confused by the literature on this…
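On points 1a, 1c and 3: if I read the greta and VGAM documentation correctly, both greta’s imultilogit() and VGAM::multilogitlink(..., inverse = TRUE) append a reference column fixed at zero to the linear predictor, so an eta with n_sp columns gives n_sp + 1 columns of probabilities, and Y_obs in the code has an implicit extra "reference species". That reference column is what would keep the site intercepts a_i identifiable in principle: a full softmax over all columns is unchanged when a constant is added to each row. A base-R sketch; softmax_rows and imultilogit_ref are hypothetical helpers written to mimic that behaviour, and the last lines check that \alpha_{ij} = \mu_{ij} \theta gives rows summing to \theta.

```r
# Row-wise softmax over all columns (no reference category);
# subtracting each row's maximum avoids overflow in exp()
softmax_rows <- function(eta) {
  e <- exp(eta - apply(eta, 1, max))
  e / rowSums(e)
}

# Inverse multilogit with a zero reference column appended last
# (assumed to mimic greta::imultilogit and VGAM::multilogitlink(inverse = TRUE))
imultilogit_ref <- function(eta) softmax_rows(cbind(eta, 0))

set.seed(2)
eta_demo    <- matrix(rnorm(5 * 3), nrow = 5)  # 5 sites x 3 non-reference species
a_demo      <- rnorm(5)                        # hypothetical site intercepts
eta_shifted <- eta_demo + a_demo               # adds a_demo[i] to every column of row i

# Full softmax: the site intercepts cancel, so a_i would not be identifiable
all.equal(softmax_rows(eta_demo), softmax_rows(eta_shifted))

# With the reference column, the shift changes the probabilities
isTRUE(all.equal(imultilogit_ref(eta_demo), imultilogit_ref(eta_shifted)))

# Concentration alpha = mu * theta: rows of mu sum to 1, rows of alpha to theta
theta_demo <- 100
mu_demo    <- imultilogit_ref(eta_demo)
alpha_demo <- exp(log(mu_demo) + log(theta_demo))
ncol(mu_demo)        # one more column than eta_demo
rowSums(alpha_demo)
```

Since the simulation and the greta model both use this reference column, it is consistent between them; the point is only that a_i shifts the real species relative to a reference whose linear predictor is always zero, which is a different direction from \theta scaling all of \alpha_i.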

Would appreciate any help/black magic from you and the community. Thanks in advance!