Pyro Primitives

param

param(name, init_value=None, **kwargs)[source]

Annotate the given site as an optimizable parameter for use with jax.experimental.optimizers. For an example of how param statements can be used in inference algorithms, refer to svi().

Parameters:
  • name (str) – name of site.
  • init_value (numpy.ndarray) – initial value specified by the user. Note that the onus of using this to initialize the optimizer is on the user / inference algorithm, since there is no global parameter store in NumPyro.
Returns:

value for the parameter. Unless wrapped inside a handler like substitute, this will simply return the initial value.
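
Example (a minimal sketch; the site name "loc" and the values are illustrative):

>>> import jax.numpy as jnp
>>> import numpyro
>>> from numpyro import handlers
>>>
>>> def model():
...     # register "loc" as an optimizable parameter site
...     return numpyro.param("loc", init_value=jnp.zeros(3))
>>>
>>> assert jnp.all(model() == 0.)  # without handlers, the initial value is returned
>>> with handlers.substitute(data={"loc": jnp.ones(3)}):
...     assert jnp.all(model() == 1.)  # substitute overrides the value at the site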

sample

sample(name, fn, obs=None, rng_key=None, sample_shape=(), infer=None)[source]

Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.

Note

By design, the sample primitive is meant to be used inside a NumPyro model. The seed handler is then used to inject a random state into fn. In that situation, the rng_key keyword has no effect.

Parameters:
  • name (str) – name of the sample site.
  • fn – a stochastic function that returns a sample.
  • obs (numpy.ndarray) – observed value
  • rng_key (jax.random.PRNGKey) – an optional random key for fn.
  • sample_shape – Shape of samples to be drawn.
  • infer (dict) – an optional dictionary containing additional information for inference algorithms. For example, if fn is a discrete distribution, setting infer={'enumerate': 'parallel'} tells MCMC to marginalize out this discrete latent site.
Returns:

sample from the stochastic fn.
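
Example (a minimal sketch; the site names are illustrative):

>>> import jax
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>>
>>> def model():
...     return numpyro.sample("x", dist.Normal(0., 1.))
>>>
>>> # inside a model, the seed handler injects the random state, so no rng_key is needed
>>> with handlers.seed(rng_seed=0):
...     x = model()
>>> assert x.shape == ()
>>>
>>> # outside the handler stack, an explicit rng_key can be passed instead
>>> y = numpyro.sample("y", dist.Normal(0., 1.), rng_key=jax.random.PRNGKey(1))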

plate

class plate(name, size, subsample_size=None, dim=None)[source]

Construct for annotating conditionally independent variables. Within a plate context manager, sample sites will be automatically broadcasted to the size of the plate. Additionally, a scale factor might be applied by certain inference algorithms if subsample_size is specified.

Parameters:
  • name (str) – Name of the plate.
  • size (int) – Size of the plate.
  • subsample_size (int) – Optional argument denoting the size of the mini-batch. Inference algorithms can use this to apply a scaling factor, e.g. when computing the ELBO over a mini-batch.
  • dim (int) – Optional argument to specify which dimension in the tensor is used as the plate dim. If None (default), the leftmost available dim is allocated.
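
Example (a minimal sketch; the plate name "N" and shapes are illustrative):

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>>
>>> def model():
...     with numpyro.plate("N", 10):
...         # the scalar Normal is broadcast to the plate size
...         return numpyro.sample("x", dist.Normal(0., 1.))
>>>
>>> with handlers.seed(rng_seed=0):
...     x = model()
>>> assert x.shape == (10,)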

plate_stack

plate_stack(prefix, sizes, rightmost_dim=-1)[source]

Create a contiguous stack of plates with dimensions:

rightmost_dim - len(sizes) + 1, ..., rightmost_dim

Parameters:
  • prefix (str) – Name prefix for plates.
  • sizes (iterable) – An iterable of plate sizes.
  • rightmost_dim (int) – The rightmost dimension to use, specified as a negative index counting from the right.
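
Example (a minimal sketch; the prefix "p" and the sizes are illustrative):

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>>
>>> def model():
...     # plates over the two rightmost dimensions, with sizes 2 and 3
...     with numpyro.plate_stack("p", sizes=[2, 3]):
...         return numpyro.sample("x", dist.Normal(0., 1.))
>>>
>>> with handlers.seed(rng_seed=0):
...     x = model()
>>> assert x.shape == (2, 3)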

deterministic

deterministic(name, value)[source]

Used to designate deterministic sites in the model. Note that most effect handlers will not operate on deterministic sites (except trace()), so deterministic sites should be side-effect free. The use case for deterministic nodes is to record any values in the model execution trace.

Parameters:
  • name (str) – name of the deterministic site.
  • value (numpy.ndarray) – deterministic value to record in the trace.
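
Example (a minimal sketch; the site names are illustrative):

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>>
>>> def model():
...     x = numpyro.sample("x", dist.Normal(0., 1.))
...     # record a derived quantity so it appears in the execution trace
...     numpyro.deterministic("x_squared", x ** 2)
>>>
>>> exec_trace = handlers.trace(handlers.seed(model, rng_seed=0)).get_trace()
>>> assert exec_trace["x_squared"]["value"] == exec_trace["x"]["value"] ** 2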

factor

factor(name, log_factor)[source]

Factor statement to add an arbitrary log probability factor to a probabilistic model.

Parameters:
  • name (str) – Name of the trivial sample.
  • log_factor (numpy.ndarray) – A possibly batched log probability factor.
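
Example (a minimal sketch; the penalty term is illustrative, and the log_density utility from numpyro.infer.util is used to inspect the joint log density):

>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer.util import log_density
>>>
>>> def model(x):
...     numpyro.sample("z", dist.Normal(0., 1.))
...     # add a soft penalty on x to the joint log density
...     numpyro.factor("penalty", -jnp.abs(x))
>>>
>>> # the factor term is included in the model's joint log density
>>> logp, _ = log_density(model, (2.,), {}, {"z": 0.})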

module

module(name, nn, input_shape=None)[source]

Declare a stax style neural network inside a model so that its parameters are registered for optimization via param() statements.

Parameters:
  • name (str) – name of the module to be registered.
  • nn (tuple) – a tuple of (init_fn, apply_fn) obtained by a stax constructor function.
  • input_shape (tuple) – shape of the input taken by the neural network.
Returns:

an apply_fn with bound parameters that takes an array as input and returns the neural network's transformed output array.
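
Example (a minimal sketch; the network architecture is illustrative, and depending on the JAX version the stax module may live under jax.experimental.stax instead of jax.example_libraries.stax):

>>> import jax.numpy as jnp
>>> from jax.example_libraries import stax
>>> import numpyro
>>> from numpyro import handlers
>>>
>>> def model(data):
...     # the network's parameters are registered for optimization via param()
...     net = numpyro.module("mlp", stax.serial(stax.Dense(4), stax.Relu, stax.Dense(1)),
...                          input_shape=(3,))
...     return net(data)
>>>
>>> with handlers.seed(rng_seed=0):  # a random state is needed to initialize the parameters
...     out = model(jnp.ones((5, 3)))
>>> assert out.shape == (5, 1)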

scan

scan(f, init, xs, length=None, reverse=False)[source]

This primitive scans a function over the leading array axes of xs while carrying along state. See jax.lax.scan() for more information.

Usage:

>>> import numpy as np
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.control_flow import scan
>>>
>>> def gaussian_hmm(y=None, T=10):
...     def transition(x_prev, y_curr):
...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
...         return x_curr, (x_curr, y_curr)
...
...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
...     _, (x, y) = scan(transition, x0, y, length=T)
...     return (x, y)
>>>
>>> # here we do some quick tests
>>> with numpyro.handlers.seed(rng_seed=0):
...     x, y = gaussian_hmm(np.arange(10.))
>>> assert x.shape == (10,) and y.shape == (10,)
>>> assert np.all(y == np.arange(10))
>>>
>>> with numpyro.handlers.seed(rng_seed=0):  # generative
...     x, y = gaussian_hmm()
>>> assert x.shape == (10,) and y.shape == (10,)

Warning

This is an experimental utility function that allows users to use JAX control flow with NumPyro’s effect handlers. Currently, sample and deterministic sites within the scan body f are supported. If you notice that any effect handlers or distributions are unsupported, please file an issue.

Note

It is ambiguous to align the scan dimension inside a plate context, so the following pattern is not supported:

with numpyro.plate('N', 10):
    last, ys = scan(f, init, xs)

All plate statements should be put inside f. For example, the corresponding working code is

def g(*args, **kwargs):
    with numpyro.plate('N', 10):
        return f(*args, **kwargs)

last, ys = scan(g, init, xs)

Note

Nested scan is currently not supported.

Note

We can scan over discrete latent variables in f. The joint density is evaluated using parallel-scan (reference [1]) over the time dimension, which reduces the parallel complexity to O(log(length)).

Currently, only the equivalence to markov(history_size=1) is supported. A trace of scan with discrete latent variables will contain the following sites:

  • init sites: those sites belong to the first trace of f. Each of
    them will have a name prefixed with _init/.
  • scanned sites: those sites collect the values of the remaining scan
    loop over f. An additional time dimension _time_foo will be added to those sites, where foo is the name of the first site that appears in f.

Not all transition functions f are supported. All of the restrictions from Pyro's enumeration tutorial [2] still apply here. In addition, no site outside of scan should depend on the first output of scan (the final carry value).

References

  1. Temporal Parallelization of Bayesian Smoothers, Simo Sarkka, Angel F. Garcia-Fernandez (https://arxiv.org/abs/1905.13002)
  2. Inference with Discrete Latent Variables (http://pyro.ai/examples/enumeration.html#Dependencies-among-plates)
Parameters:
  • f (callable) – a function to be scanned.
  • init – the initial carrying state
  • xs – the values over which we scan along the leading axis. This can be any JAX pytree (e.g. list/dict of arrays).
  • length – optional integer specifying the number of scan iterations; this is needed when xs is an empty pytree (e.g. None)
  • reverse (bool) – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse
Returns:

output of scan, quoted from jax.lax.scan() docs: “pair of type (c, [b]) where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of f when scanned over the leading axis of the inputs”.