Note
Click here to download the full example code
Example: CJS Capture-Recapture Model for Ecological Data
This example is ported from [8].
We show how to implement several variants of the Cormack-Jolly-Seber (CJS) [4, 5, 6] model used in ecology to analyze animal capture-recapture data. For a discussion of these models see reference [1].
We make use of two datasets:
- the European Dipper (Cinclus cinclus) data from reference [2] (the dipper is Norway’s national bird);
- the meadow vole data from reference [3].
Compare to the Stan implementations in [7].
References
1. Kery, M., & Schaub, M. (2011). Bayesian population analysis using WinBUGS: a hierarchical perspective. Academic Press.
2. Lebreton, J.D., Burnham, K.P., Clobert, J., & Anderson, D.R. (1992). Modeling survival and testing biological hypotheses using marked animals: a unified approach with case studies. Ecological monographs, 62(1), 67-118.
3. Nichols, Pollock, Hines (1984). The use of a robust capture-recapture design in small mammal population studies: A field example with Microtus pennsylvanicus. Acta Theriologica 29:357-365.
4. Cormack, R.M., 1964. Estimates of survival from the sighting of marked animals. Biometrika 51, 429-438.
5. Jolly, G.M., 1965. Explicit estimates from capture-recapture data with both death and immigration-stochastic model. Biometrika 52, 225-247.
6. Seber, G.A.F., 1965. A note on the multiple recapture census. Biometrika 52, 249-259.
7. https://github.com/stan-dev/example-models/tree/master/BPA/Ch.07
import argparse
import os
from jax import random
import jax.numpy as jnp
from jax.scipy.special import expit, logit
import numpyro
from numpyro import handlers
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist
from numpyro.examples.datasets import DIPPER_VOLE, load_dataset
from numpyro.infer import HMC, MCMC, NUTS
from numpyro.infer.reparam import LocScaleReparam
Our first and simplest CJS model variant only has two continuous (scalar) latent random variables: i) the survival probability phi; and ii) the recapture probability rho. These are treated as fixed effects with no temporal or individual/group variation.
def model_1(capture_history, sex):
    """CJS variant with one survival and one recapture probability.

    Both ``phi`` (survival) and ``rho`` (recapture) are scalar latent
    variables with Uniform(0, 1) priors, shared by all individuals and all
    time periods. The discrete alive/dead state ``z`` is enumerated in
    parallel rather than sampled.

    :param capture_history: 2-D array of capture indicators with one row per
        individual and one column per time period.
    :param sex: not used by this variant; accepted so that every model
        variant can be called with the same arguments.
    """
    num_animals, _ = capture_history.shape
    phi = numpyro.sample("phi", dist.Uniform(0.0, 1.0))  # survival probability
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def step(state, captured_t):
        seen_before, alive_prev = state
        with numpyro.plate("animals", num_animals, dim=-1), handlers.mask(
            mask=seen_before
        ):
            # Before an animal's first capture the "alive" probability is
            # pinned to one; afterwards it is phi times the previous state.
            prob_alive = (1 - seen_before) + seen_before * phi * alive_prev
            # NumPyro exactly sums out the discrete states z_t.
            alive = numpyro.sample(
                "z",
                dist.Bernoulli(dist.util.clamp_probs(prob_alive)),
                infer={"enumerate": "parallel"},
            )
            # Only animals that are alive can be recaptured.
            prob_seen = rho * alive
            numpyro.sample(
                "y",
                dist.Bernoulli(dist.util.clamp_probs(prob_seen)),
                obs=captured_t,
            )

        # An animal counts as "seen" from its first capture onwards.
        return (seen_before | captured_t.astype(bool), alive), None

    # The mask removes the extraneous log probabilities that would otherwise
    # be accumulated for an individual before its first capture.
    seen_at_start = capture_history[:, 0].astype(bool)
    alive_at_start = jnp.ones(num_animals, dtype=jnp.int32)
    # Put the time axis first so that `scan` iterates over time periods.
    later_periods = jnp.swapaxes(capture_history[:, 1:], 0, 1)
    scan(step, (seen_at_start, alive_at_start), later_periods)
In our second model variant there is a time-varying survival probability phi_t for T-1 of the T time periods of the capture data; each phi_t is treated as a fixed effect.
def model_2(capture_history, sex):
    """CJS variant with a separate survival probability for each time step.

    A fresh ``phi`` with a Uniform(0, 1) prior is sampled at every scanned
    time step and shared by all individuals; the recapture probability
    ``rho`` is a single scalar. The discrete alive/dead state ``z`` is
    enumerated in parallel rather than sampled.

    :param capture_history: 2-D array of capture indicators with one row per
        individual and one column per time period.
    :param sex: not used by this variant; accepted so that every model
        variant can be called with the same arguments.
    """
    num_animals, _ = capture_history.shape
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def step(state, captured_t):
        seen_before, alive_prev = state
        # phi_t is drawn outside the plate because a single value is shared
        # across all N individuals at this time step.
        phi_t = numpyro.sample("phi", dist.Uniform(0.0, 1.0))
        with numpyro.plate("animals", num_animals, dim=-1), handlers.mask(
            mask=seen_before
        ):
            # Before an animal's first capture the "alive" probability is
            # pinned to one; afterwards it is phi_t times the previous state.
            prob_alive = (1 - seen_before) + seen_before * phi_t * alive_prev
            # NumPyro exactly sums out the discrete states z_t.
            alive = numpyro.sample(
                "z",
                dist.Bernoulli(dist.util.clamp_probs(prob_alive)),
                infer={"enumerate": "parallel"},
            )
            # Only animals that are alive can be recaptured.
            prob_seen = rho * alive
            numpyro.sample(
                "y",
                dist.Bernoulli(dist.util.clamp_probs(prob_seen)),
                obs=captured_t,
            )

        # An animal counts as "seen" from its first capture onwards.
        return (seen_before | captured_t.astype(bool), alive), None

    # The mask removes the extraneous log probabilities that would otherwise
    # be accumulated for an individual before its first capture.
    seen_at_start = capture_history[:, 0].astype(bool)
    alive_at_start = jnp.ones(num_animals, dtype=jnp.int32)
    # Put the time axis first so that `scan` iterates over time periods.
    later_periods = jnp.swapaxes(capture_history[:, 1:], 0, 1)
    scan(step, (seen_at_start, alive_at_start), later_periods)
In our third model variant there is a survival probability phi_t for T-1 of the T time periods of the capture data (just like in model_2), but here each phi_t is treated as a random effect.
def model_3(capture_history, sex):
    """CJS variant with per-time-step survival treated as a random effect.

    At each scanned time step the logit of the survival probability is drawn
    from ``Normal(logit(phi_mean), phi_sigma)``, where ``phi_mean`` and
    ``phi_sigma`` are global latent variables. The recapture probability
    ``rho`` is a single scalar, and the discrete alive/dead state ``z`` is
    enumerated in parallel rather than sampled.

    :param capture_history: 2-D array of capture indicators with one row per
        individual and one column per time period.
    :param sex: not used by this variant; accepted so that every model
        variant can be called with the same arguments.
    """
    num_animals, _ = capture_history.shape
    # mean survival probability
    phi_mean = numpyro.sample("phi_mean", dist.Uniform(0.0, 1.0))
    phi_logit_mean = logit(phi_mean)
    # controls temporal variability of survival probability
    phi_sigma = numpyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def step(state, captured_t):
        seen_before, alive_prev = state
        # Reparameterise the per-step logit with LocScaleReparam(0)
        # (a non-centered parameterisation of the Normal site).
        reparam_config = {"phi_logit": LocScaleReparam(0)}
        with handlers.reparam(config=reparam_config):
            phi_logit_t = numpyro.sample(
                "phi_logit", dist.Normal(phi_logit_mean, phi_sigma)
            )
        phi_t = expit(phi_logit_t)
        with numpyro.plate("animals", num_animals, dim=-1), handlers.mask(
            mask=seen_before
        ):
            # Before an animal's first capture the "alive" probability is
            # pinned to one; afterwards it is phi_t times the previous state.
            prob_alive = (1 - seen_before) + seen_before * phi_t * alive_prev
            # NumPyro exactly sums out the discrete states z_t.
            alive = numpyro.sample(
                "z",
                dist.Bernoulli(dist.util.clamp_probs(prob_alive)),
                infer={"enumerate": "parallel"},
            )
            # Only animals that are alive can be recaptured.
            prob_seen = rho * alive
            numpyro.sample(
                "y",
                dist.Bernoulli(dist.util.clamp_probs(prob_seen)),
                obs=captured_t,
            )

        # An animal counts as "seen" from its first capture onwards.
        return (seen_before | captured_t.astype(bool), alive), None

    # The mask removes the extraneous log probabilities that would otherwise
    # be accumulated for an individual before its first capture.
    seen_at_start = capture_history[:, 0].astype(bool)
    alive_at_start = jnp.ones(num_animals, dtype=jnp.int32)
    # Put the time axis first so that `scan` iterates over time periods.
    later_periods = jnp.swapaxes(capture_history[:, 1:], 0, 1)
    scan(step, (seen_at_start, alive_at_start), later_periods)
In our fourth model variant we include group-level fixed effects for sex (male, female).
def model_4(capture_history, sex):
    """CJS variant with group-level (sex) fixed effects on survival.

    Separate survival probabilities ``phi_male`` and ``phi_female`` with
    Uniform(0, 1) priors are combined into a per-individual survival
    probability according to ``sex``. The recapture probability ``rho`` is a
    single scalar, and the discrete alive/dead state ``z`` is enumerated in
    parallel rather than sampled.

    :param capture_history: 2-D array of capture indicators with one row per
        individual and one column per time period.
    :param sex: length-N vector encoding each individual's sex
        (female = 0, male = 1).
    :raises ValueError: if ``sex`` is ``None``.
    """
    # Fail early with a clear message instead of an obscure arithmetic error
    # on `None` when the dataset carries no sex information.
    if sex is None:
        raise ValueError(
            "model_4 requires per-individual sex information, but `sex` is None."
        )

    N, _ = capture_history.shape
    # survival probabilities for males/females
    phi_male = numpyro.sample("phi_male", dist.Uniform(0.0, 1.0))
    phi_female = numpyro.sample("phi_female", dist.Uniform(0.0, 1.0))
    # we construct a N-dimensional vector that contains the appropriate
    # phi for each individual given its sex (female = 0, male = 1)
    phi = sex * phi_male + (1.0 - sex) * phi_female
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                # Before an animal's first capture the "alive" probability is
                # pinned to one; afterwards it is phi times the previous state.
                mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample(
                    "z",
                    dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                    infer={"enumerate": "parallel"},
                )
                # Only animals that are alive can be recaptured.
                mu_y_t = rho * z
                numpyro.sample(
                    "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                )

        # An animal counts as "seen" from its first capture onwards.
        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(
        transition_fn,
        (first_capture_mask, z),
        jnp.swapaxes(capture_history[:, 1:], 0, 1),
    )
In our final model variant we include both fixed group effects and fixed time effects for the survival probability phi: $\operatorname{logit}(\phi_t) = \beta_{\text{group}} + \gamma_t$. We need to take care that the model is not overparameterized; to do this we effectively let a single scalar beta encode the difference in male and female survival probabilities.
def model_5(capture_history, sex):
    """CJS variant with both a sex effect and a time effect on survival.

    The survival probability at each scanned time step is
    ``expit(sex * phi_beta + phi_gamma_t)``, where ``phi_beta`` is a single
    scalar and ``phi_gamma_t`` is sampled afresh at every step, both with
    Normal(0, 10) priors. The recapture probability ``rho`` is a single
    scalar, and the discrete alive/dead state ``z`` is enumerated in parallel
    rather than sampled.

    :param capture_history: 2-D array of capture indicators with one row per
        individual and one column per time period.
    :param sex: per-individual vector that multiplies ``phi_beta``.
    :raises ValueError: if ``sex`` is ``None``.
    """
    # Fail early with a clear message instead of an obscure arithmetic error
    # on `None` when the dataset carries no sex information.
    if sex is None:
        raise ValueError(
            "model_5 requires per-individual sex information, but `sex` is None."
        )

    N, _ = capture_history.shape
    # phi_beta controls the survival probability differential
    # for males versus females (in logit space)
    phi_beta = numpyro.sample("phi_beta", dist.Normal(0.0, 10.0))
    # Per-individual logit offset: zero wherever `sex` is zero.
    phi_beta = sex * phi_beta
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        # The time effect is sampled outside the plate, so one value is
        # shared by all individuals at this time step.
        phi_gamma_t = numpyro.sample("phi_gamma", dist.Normal(0.0, 10.0))
        phi_t = expit(phi_beta + phi_gamma_t)
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                # Before an animal's first capture the "alive" probability is
                # pinned to one; afterwards it is phi_t times the previous state.
                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample(
                    "z",
                    dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                    infer={"enumerate": "parallel"},
                )
                # Only animals that are alive can be recaptured.
                mu_y_t = rho * z
                numpyro.sample(
                    "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                )

        # An animal counts as "seen" from its first capture onwards.
        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(
        transition_fn,
        (first_capture_mask, z),
        jnp.swapaxes(capture_history[:, 1:], 0, 1),
    )
Do inference
# Registry of the model variants defined in this module, keyed by the part of
# the function name that follows "model_" (e.g. "1" -> model_1).
models = dict(
    (identifier.split("_", 1)[1], candidate)
    for identifier, candidate in globals().items()
    if identifier.startswith("model_")
)
def run_inference(model, capture_history, sex, rng_key, args):
    """Run MCMC on ``model`` and return the posterior samples.

    A summary of the posterior is printed after sampling.

    :param model: model callable taking ``(capture_history, sex)``.
    :param capture_history: capture data forwarded to the model.
    :param sex: sex covariate forwarded to the model (may be ``None``).
    :param rng_key: random key used to run the sampler.
    :param args: namespace providing ``algo`` ("NUTS" or "HMC"),
        ``num_warmup``, ``num_samples`` and ``num_chains``.
    :return: the samples returned by ``MCMC.get_samples()``.
    :raises ValueError: if ``args.algo`` is neither "NUTS" nor "HMC".
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    else:
        # Without this branch `kernel` would be unbound below and the caller
        # would see an UnboundLocalError instead of a usable message.
        raise ValueError(
            "Unknown algorithm '{}'; expected 'NUTS' or 'HMC'.".format(args.algo)
        )

    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set.
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, capture_history, sex)
    mcmc.print_summary()
    return mcmc.get_samples()
def main(args):
    """Load the requested dataset and run inference for the requested model.

    :param args: namespace providing ``model``, ``dataset`` and ``rng_seed``,
        plus the sampler options read by ``run_inference``.
    :raises ValueError: if ``args.model`` is not a known model, if
        ``args.dataset`` is not "dipper" or "vole", or if model 4 or 5 is
        requested for the vole data (which has no sex information).
    """
    # Validate the model name before touching the dataset so that a typo is
    # reported immediately and with a clear message rather than as a KeyError
    # after the data has been loaded.
    if args.model not in models:
        raise ValueError(
            "Unknown model '{}'; available models are: {}.".format(
                args.model, ", ".join(sorted(models))
            )
        )

    # load data
    if args.dataset == "dipper":
        fetch = load_dataset(DIPPER_VOLE, split="dipper", shuffle=False)[1]
        capture_history, sex = fetch()
    elif args.dataset == "vole":
        if args.model in ["4", "5"]:
            raise ValueError(
                "Cannot run model_{} on meadow voles data, since we lack sex "
                "information for these animals.".format(args.model)
            )
        fetch = load_dataset(DIPPER_VOLE, split="vole", shuffle=False)[1]
        (capture_history,) = fetch()
        sex = None
    else:
        raise ValueError("Available datasets are 'dipper' and 'vole'.")

    N, T = capture_history.shape
    print(
        "Loaded {} capture history for {} individuals collected over {} time periods.".format(
            args.dataset, N, T
        )
    )

    model = models[args.model]
    rng_key = random.PRNGKey(args.rng_seed)
    run_inference(model, capture_history, sex, rng_key, args)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to `main`.
    parser = argparse.ArgumentParser(
        description="CJS capture-recapture model for ecological data"
    )
    # `choices` makes argparse reject unknown values up front with a usage
    # message instead of failing later inside `main` / `run_inference`.
    parser.add_argument(
        "-m",
        "--model",
        default="1",
        type=str,
        choices=sorted(models),
        help="one of: {}".format(", ".join(sorted(models.keys()))),
    )
    parser.add_argument(
        "-d", "--dataset", default="dipper", type=str, choices=["dipper", "vole"]
    )
    # No `nargs="?"` on the integer options: without a `const`, a bare flag
    # would silently set the value to None and break the sampler set-up.
    parser.add_argument("-n", "--num-samples", default=1000, type=int)
    parser.add_argument("--num-warmup", default=1000, type=int)
    parser.add_argument("--num-chains", default=1, type=int)
    # `--rng-seed` matches the hyphenated style of the other options;
    # the original `--rng_seed` spelling is kept as an alias.
    parser.add_argument(
        "--rng-seed",
        "--rng_seed",
        dest="rng_seed",
        default=0,
        type=int,
        help="random number generator seed",
    )
    parser.add_argument(
        "--algo",
        default="NUTS",
        type=str,
        choices=["NUTS", "HMC"],
        help='whether to run "NUTS" or "HMC"',
    )
    args = parser.parse_args()
    main(args)