Note
Click here to download the full example code
Example: AR2 process
In this example we show how to use jax.lax.scan
to avoid writing a (slow) Python for-loop. In this toy
example, with --num-data=1000, the improvement is almost 10x.
To demonstrate, we will be implementing an AR2 process. The idea is that we have some time series
\[y_0, y_1, ..., y_T\]
and we seek parameters \(c\), \(\alpha_1\), and \(\alpha_2\) such that for each \(t\) between \(2\) and \(T\), we have
\[y_t = c + \alpha_1 y_{t-1} + \alpha_2 y_{t-2} + \epsilon_t\]
where \(\epsilon_t\) is an error term.
import argparse
import os
import time
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
matplotlib.use("Agg")
def ar2(y, unroll_loop=False):
    """NumPyro model of a second-order autoregressive (AR2) process.

    For every index ``t >= 2`` the observation is modelled as
    ``y[t] ~ Normal(const + alpha_1 * y[t - 1] + alpha_2 * y[t - 2], sigma)``.
    The first two entries of ``y`` are only used as lagged inputs.

    :param y: observed series, indexed along its first axis.
    :param bool unroll_loop: if True, build the predictions with a Python
        for-loop; otherwise use ``jax.lax.scan``.
    """
    alpha_1 = numpyro.sample("alpha_1", dist.Normal(0, 1))
    alpha_2 = numpyro.sample("alpha_2", dist.Normal(0, 1))
    const = numpyro.sample("const", dist.Normal(0, 1))
    # sigma is the scale of the observation Normal below, so its prior must
    # only have support on positive values.
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))

    def transition_fn(carry, y_t):
        # carry holds the two most recent observations: (y[t-1], y[t-2]).
        y_1, y_2 = carry
        pred = const + alpha_1 * y_1 + alpha_2 * y_2
        # Shift the window: the current observation becomes the newest lag.
        return (y_t, y_1), pred

    if unroll_loop:
        # Deliberately explicit Python loop, kept for comparison with scan.
        preds = []
        for i in range(2, len(y)):
            preds.append(const + alpha_1 * y[i - 1] + alpha_2 * y[i - 2])
        preds = jnp.asarray(preds)
    else:
        _, preds = jax.lax.scan(transition_fn, (y[1], y[0]), y[2:])

    # Record the predictions under the "mu" site so they are returned with
    # the posterior samples.
    mu = numpyro.deterministic("mu", preds)
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y[2:])
def run_inference(model, args, rng_key, y):
    """Run NUTS-based MCMC on ``model`` and return the collected samples.

    Prints the MCMC summary and the wall-clock time elapsed since the start
    of this function.

    :param model: model callable handed to the NUTS kernel.
    :param args: object providing ``num_warmup``, ``num_samples``,
        ``num_chains`` and ``unroll_loop`` attributes.
    :param rng_key: random key passed to ``MCMC.run``.
    :param y: observed series forwarded to the model as ``y``.
    :return: the result of ``MCMC.get_samples()``.
    """
    t0 = time.time()

    # The progress bar is switched off when NUMPYRO_SPHINXBUILD is set.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ

    kernel = numpyro.infer.NUTS(model)
    runner = numpyro.infer.MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    runner.run(rng_key, y=y, unroll_loop=args.unroll_loop)
    runner.print_summary()

    elapsed = time.time() - t0
    print("\nMCMC elapsed time:", elapsed)
    return runner.get_samples()
def main(args):
    """Fit the AR2 model to a noisy sine wave and plot the mean prediction.

    Generates ``args.num_data`` points of ``sin(t)`` plus Gaussian noise,
    runs inference with ``run_inference``, and saves a plot of the data and
    the posterior-mean predictions to ``ar2_plot.pdf``.

    :param args: parsed command-line arguments; ``num_data`` is read here and
        the whole object is forwarded to ``run_inference``.
    """
    # generate artificial dataset: a sine wave with additive Gaussian noise
    n_points = args.num_data
    t = np.arange(n_points)
    noise = np.random.randn(n_points) * 0.1
    y = np.sin(t) + noise

    # do inference
    rng_key = random.split(random.PRNGKey(0))[0]
    samples = run_inference(ar2, args, rng_key, y)

    # do prediction: average the recorded "mu" site over the sample axis
    mean_prediction = samples["mu"].mean(axis=0)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    # plot training data
    ax.plot(t, y, color="blue", label="True values")

    # plot mean prediction
    # note that we can't make predictions for the first two points,
    # because they don't have lagged values to use for prediction.
    ax.plot(t[2:], mean_prediction, color="orange", label="Mean predictions")

    ax.set_xlabel("time")
    ax.set_ylabel("y")
    ax.set_title("AR2 process")
    ax.legend()

    plt.savefig("ar2_plot.pdf")
if __name__ == "__main__":
    # This example is pinned to a specific NumPyro release.
    assert numpyro.__version__.startswith("0.9.1")

    parser = argparse.ArgumentParser(description="AR2 example")

    # Integer-valued options all share the same argparse settings.
    for flags, default in (
        (("--num-data",), 142),
        (("-n", "--num-samples"), 1000),
        (("--num-warmup",), 1000),
        (("--num-chains",), 1),
    ):
        parser.add_argument(*flags, nargs="?", default=default, type=int)

    parser.add_argument(
        "--device", default="cpu", type=str, help='use "cpu" or "gpu".'
    )
    parser.add_argument(
        "--unroll-loop",
        action="store_true",
        help="whether to unroll for-loop (note: slower)",
    )
    args = parser.parse_args()

    # Configure NumPyro from the parsed options before running the example.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)