Stochastic Variational Inference (SVI)

We offer a brief overview of the four most commonly used ELBO implementations in NumPyro (a short sketch of choosing a loss follows this list):

  • Trace_ELBO is our basic ELBO implementation.

  • TraceMeanField_ELBO is like Trace_ELBO but computes part of the ELBO analytically if doing so is possible.

  • TraceGraph_ELBO offers variance reduction strategies for models with discrete latent variables. Generally speaking, this ELBO should be used for models with discrete latent variables that cannot be enumerated.

  • TraceEnum_ELBO offers variable enumeration strategies for models with discrete latent variables. Generally speaking, this ELBO should always be used for models with discrete latent variables when enumeration is possible.
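
The loss objective is selected when constructing SVI. A minimal sketch, assuming a model, guide, and optimizer defined as in the SVI example below (TraceGraph_ELBO would likewise be imported from numpyro.infer):

    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    # for a model with discrete latent sites that cannot be enumerated,
    # one might instead use
    svi = SVI(model, guide, optimizer, loss=TraceGraph_ELBO())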

class SVI(model, guide, optim, loss, **static_kwargs)[source]

Bases: object

Stochastic Variational Inference given an ELBO loss objective.

References

  1. SVI Part I: An Introduction to Stochastic Variational Inference in Pyro, (http://pyro.ai/examples/svi_part_i.html)

Example:

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import Predictive, SVI, Trace_ELBO

>>> def model(data):
...     f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
...     with numpyro.plate("N", data.shape[0] if data is not None else 10):
...         numpyro.sample("obs", dist.Bernoulli(f), obs=data)

>>> def guide(data):
...     alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key),
...                            constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

>>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
>>> optimizer = numpyro.optim.Adam(step_size=0.0005)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 2000, data)
>>> params = svi_result.params
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
>>> # use guide to make predictive
>>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
>>> samples = predictive(random.PRNGKey(1), data=None)
>>> # get posterior samples
>>> predictive = Predictive(guide, params=params, num_samples=1000)
>>> posterior_samples = predictive(random.PRNGKey(1), data=None)
>>> # use posterior samples to make predictive
>>> predictive = Predictive(model, posterior_samples, params=params, num_samples=1000)
>>> samples = predictive(random.PRNGKey(1), data=None)
Parameters:
  • model – Python callable with Pyro primitives for the model.

  • guide – Python callable with Pyro primitives for the guide (recognition network).

  • optim

    An instance of _NumpyroOptim, a jax.example_libraries.optimizers.Optimizer or an Optax GradientTransformation. If you pass an Optax optimizer it will automatically be wrapped using numpyro.optim.optax_to_numpyro().

    >>> from optax import adam, chain, clip
    >>> svi = SVI(model, guide, chain(clip(10.0), adam(1e-3)), loss=Trace_ELBO())
    

  • loss – ELBO loss, i.e. negative Evidence Lower Bound, to minimize.

  • static_kwargs – static arguments for the model / guide, i.e. arguments that remain constant during fitting.
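
    For example, a hedged sketch reusing the model, guide, optimizer, and data from the example above: since the dataset never changes during fitting, it could be supplied once as a static keyword argument instead of at every call.

    >>> svi_static = SVI(model, guide, optimizer, loss=Trace_ELBO(), data=data)
    >>> svi_result = svi_static.run(random.PRNGKey(0), 2000)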

Returns:

tuple of (init_fn, update_fn, evaluate).

init(rng_key, *args, init_params=None, **kwargs)[source]

Gets the initial SVI state.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • init_params (dict) – if not None, initialize numpyro.param sites with values from this dictionary instead of using init_value in numpyro.param primitives.

  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).

Returns:

the initial SVIState

get_params(svi_state)[source]

Gets values at param sites of the model and guide.

Parameters:

svi_state – current state of SVI.

Returns:

the corresponding parameters

update(svi_state, *args, forward_mode_differentiation=False, **kwargs)[source]

Take a single step of SVI (possibly on a batch / minibatch of data), using the optimizer.

Parameters:
  • svi_state – current state of SVI.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation. Defaults to False.

  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).

Returns:

tuple of (svi_state, loss).
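
A minimal sketch of a hand-rolled training loop built from init(), update(), and get_params(), reusing the model, guide, optimizer, and data from the class example above (jax.jit can optionally be applied to svi.update for speed):

    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), data)
    for step in range(1000):
        svi_state, loss = svi.update(svi_state, data)
    params = svi.get_params(svi_state)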

stable_update(svi_state, *args, forward_mode_differentiation=False, **kwargs)[source]

Similar to update() but returns the current state if the loss or the new state contains invalid values.

Parameters:
  • svi_state – current state of SVI.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation. Defaults to False.

  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).

Returns:

tuple of (svi_state, loss).

run(rng_key, num_steps, *args, progress_bar=True, stable_update=False, forward_mode_differentiation=False, init_state=None, init_params=None, **kwargs)[source]

(EXPERIMENTAL INTERFACE) Run SVI with num_steps iterations, then return the optimized parameters and the stacked losses at every step. If num_steps is large, setting progress_bar=False can make the run faster.

Note

For a complex training process (e.g. one that requires early stopping, epoch-based training, varying args/kwargs, …), we recommend using the more flexible methods init(), update(), and evaluate() to customize your training procedure.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • num_steps (int) – the number of optimization steps.

  • args – arguments to the model / guide

  • progress_bar (bool) – Whether to enable progress bar updates. Defaults to True.

  • stable_update (bool) – whether to use stable_update() to update the state. Defaults to False.

  • forward_mode_differentiation (bool) – whether to use forward-mode differentiation or reverse-mode differentiation. By default, we use reverse mode, but forward mode can be useful in some cases to improve performance. In addition, some control-flow utilities in JAX, such as jax.lax.while_loop or jax.lax.fori_loop, only support forward-mode differentiation. See JAX’s The Autodiff Cookbook for more information.

  • init_state (SVIState) –

    if not None, begin SVI from the final state of previous SVI run. Usage:

    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 2000, data)
    # upon inspection of svi_result the user decides that the model has not converged
    # continue from the end of the previous svi run rather than beginning again from iteration 0
    svi_result = svi.run(random.PRNGKey(1), 2000, data, init_state=svi_result.state)
    

  • init_params (dict) – if not None, initialize numpyro.param sites with values from this dictionary instead of using init_value in numpyro.param primitives (see the short example after this list).

  • kwargs – keyword arguments to the model / guide
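
For example, reusing the guide from the class example above, whose parameter sites are "alpha_q" and "beta_q", initial values can be supplied explicitly via init_params:

    svi_result = svi.run(random.PRNGKey(0), 2000, data,
                         init_params={"alpha_q": 20.0, "beta_q": 20.0})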

Returns:

a namedtuple with fields params and losses, where params holds the optimized values at numpyro.param sites and losses holds the loss values collected at every step.

Return type:

SVIRunResult

evaluate(svi_state, *args, **kwargs)[source]

Evaluate the ELBO loss given the current parameter values (possibly on a batch / minibatch of data), without updating the parameters.

Parameters:
  • svi_state – current state of SVI.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • kwargs – keyword arguments to the model / guide.

Returns:

the ELBO loss given the current parameter values (held within svi_state.optim_state).
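
For example, a hedged sketch of monitoring the loss on held-out data (test_data is a hypothetical array of the same form as data):

    test_loss = svi.evaluate(svi_state, test_data)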

SVIState = <class 'numpyro.infer.svi.SVIState'>
A namedtuple() consisting of the following fields:
  • optim_state - current optimizer’s state.

  • mutable_state - extra state to store values of “mutable” sites.

  • rng_key - random number generator seed used for the iteration.

SVIRunResult = <class 'numpyro.infer.svi.SVIRunResult'>
A namedtuple() consisting of the following fields:
  • params - the optimized parameters.

  • state - the last SVIState.

  • losses - the losses collected at every step.

ELBO

class ELBO(num_particles=1, vectorize_particles=True)[source]

Bases: object

Base class for all ELBO objectives.

Subclasses should implement either loss() or loss_with_mutable_state().

Parameters:
  • num_particles – The number of particles/samples used to form the ELBO (gradient) estimators.

  • vectorize_particles – Whether to use jax.vmap to compute ELBOs over the num_particles-many particles in parallel. If False use jax.lax.map. Defaults to True.

can_infer_discrete = False
loss(rng_key, param_map, model, guide, *args, **kwargs)[source]

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • param_map (dict) – dictionary of current parameter values keyed by site name.

  • model – Python callable with NumPyro primitives for the model.

  • guide – Python callable with NumPyro primitives for the guide.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).

Returns:

negative of the Evidence Lower Bound (ELBO) to be minimized.
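
A hedged sketch of evaluating the objective directly with a concrete subclass, reusing the model, guide, data, and fitted params from the SVI example above, and using several particles to reduce the variance of the estimate:

    elbo = Trace_ELBO(num_particles=10)
    loss = elbo.loss(random.PRNGKey(0), params, model, guide, data)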

loss_with_mutable_state(rng_key, param_map, model, guide, *args, **kwargs)[source]

Like loss() but also update and return the mutable state, which stores the values at mutable() sites.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • param_map (dict) – dictionary of current parameter values keyed by site name.

  • model – Python callable with NumPyro primitives for the model.

  • guide – Python callable with NumPyro primitives for the guide.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).

Returns:

a dictionary containing the ELBO loss and the mutable state

Trace_ELBO

class Trace_ELBO(num_particles=1, vectorize_particles=True, multi_sample_guide=False)[source]

Bases: ELBO

A trace implementation of ELBO-based SVI. The estimator is constructed along the lines of references [1] and [2]. There are no restrictions on the dependency structure of the model or the guide.

This is the most basic implementation of the Evidence Lower Bound, which is the fundamental objective in Variational Inference. This implementation has various limitations (for example it only supports random variables with reparameterized samplers) but can be used as a template to build more sophisticated loss objectives.

For more details, refer to http://pyro.ai/examples/svi_part_i.html.

References:

  1. Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber

  2. Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei

Parameters:
  • num_particles – The number of particles/samples used to form the ELBO (gradient) estimators.

  • vectorize_particles – Whether to use jax.vmap to compute ELBOs over the num_particles-many particles in parallel. If False use jax.lax.map. Defaults to True.

  • multi_sample_guide – Whether to assume that the guide proposes multiple samples.

loss_with_mutable_state(rng_key, param_map, model, guide, *args, **kwargs)[source]

Like loss() but also update and return the mutable state, which stores the values at mutable() sites.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • param_map (dict) – dictionary of current parameter values keyed by site name.

  • model – Python callable with NumPyro primitives for the model.

  • guide – Python callable with NumPyro primitives for the guide.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).

Returns:

a dictionary containing the ELBO loss and the mutable state

TraceEnum_ELBO

class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, vectorize_particles=True)[source]

Bases: ELBO

(EXPERIMENTAL) A TraceEnum implementation of ELBO-based SVI. The gradient estimator is constructed along the lines of reference [1] specialized to the case of the ELBO. It supports arbitrary dependency structure for the model and guide. Fine-grained conditional dependency information as recorded in the trace is used to reduce the variance of the gradient estimator. In particular provenance tracking [2] is used to find the cost terms that depend on each non-reparameterizable sample site. Enumerated variables are eliminated using the TVE algorithm for plated factor graphs [3].

References

[1] Storchastic: A Framework for General Stochastic Automatic Differentiation,

Emile van Krieken, Jakub M. Tomczak, Annette ten Teije

[2] Nonstandard Interpretations of Probabilistic Programs for Efficient Inference,

David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind

[3] Tensor Variable Elimination for Plated Factor Graphs,

Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu, Neeraj Pradhan, Alexander M. Rush, Noah Goodman
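
Example (a hedged sketch of a hypothetical two-component mixture, assuming the funsor backend is installed; the data values are made up for illustration):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, TraceEnum_ELBO

def model(data):
    # component locations are optimized as a param site
    locs = numpyro.param("locs", jnp.array([-1.0, 1.0]))
    with numpyro.plate("N", data.shape[0]):
        # discrete assignment marked for parallel enumeration
        z = numpyro.sample("z", dist.Categorical(jnp.array([0.5, 0.5])),
                           infer={"enumerate": "parallel"})
        numpyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)

def guide(data):
    pass  # "z" is enumerated out of the ELBO, so the guide need not sample it

svi = SVI(model, guide, numpyro.optim.Adam(0.01),
          loss=TraceEnum_ELBO(max_plate_nesting=1))
svi_result = svi.run(random.PRNGKey(0), 1000, jnp.array([-1.2, -0.8, 1.1, 0.9]))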

can_infer_discrete = True
loss(rng_key, param_map, model, guide, *args, **kwargs)[source]

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • param_map (dict) – dictionary of current parameter values keyed by site name.

  • model – Python callable with NumPyro primitives for the model.

  • guide – Python callable with NumPyro primitives for the guide.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).

Returns:

negative of the Evidence Lower Bound (ELBO) to be minimized.

TraceGraph_ELBO

class TraceGraph_ELBO(num_particles=1, vectorize_particles=True)[source]

Bases: ELBO

A TraceGraph implementation of ELBO-based SVI. The gradient estimator is constructed along the lines of reference [1] specialized to the case of the ELBO. It supports arbitrary dependency structure for the model and guide. Fine-grained conditional dependency information as recorded in the trace is used to reduce the variance of the gradient estimator. In particular provenance tracking [2] is used to find the cost terms that depend on each non-reparameterizable sample site.

References

[1] Gradient Estimation Using Stochastic Computation Graphs,

John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

[2] Nonstandard Interpretations of Probabilistic Programs for Efficient Inference,

David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
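
Example (a hedged sketch with a hypothetical non-reparameterizable Bernoulli latent, where the score-function estimator with variance reduction applies; the data values are made up for illustration):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, TraceGraph_ELBO

def model(data):
    z = numpyro.sample("z", dist.Bernoulli(0.5))  # discrete latent
    loc = jnp.where(z == 1, 2.0, -2.0)
    with numpyro.plate("N", data.shape[0]):
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    q_z = numpyro.param("q_z", 0.5, constraint=constraints.unit_interval)
    numpyro.sample("z", dist.Bernoulli(q_z))

svi = SVI(model, guide, numpyro.optim.Adam(0.01), loss=TraceGraph_ELBO())
svi_result = svi.run(random.PRNGKey(0), 1000, jnp.array([1.8, 2.2, 2.1, 1.9]))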

can_infer_discrete = True
loss(rng_key, param_map, model, guide, *args, **kwargs)[source]

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • param_map (dict) – dictionary of current parameter values keyed by site name.

  • model – Python callable with NumPyro primitives for the model.

  • guide – Python callable with NumPyro primitives for the guide.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).

Returns:

negative of the Evidence Lower Bound (ELBO) to be minimized.

TraceMeanField_ELBO

class TraceMeanField_ELBO(num_particles=1, vectorize_particles=True)[source]

Bases: ELBO

A trace implementation of ELBO-based SVI. This is currently the only ELBO estimator in NumPyro that uses analytic KL divergences when those are available.

Warning

This estimator may give incorrect results if the mean-field condition is not satisfied. The mean field condition is a sufficient but not necessary condition for this estimator to be correct. The precise condition is that for every latent variable z in the guide, its parents in the model must not include any latent variables that are descendants of z in the guide. Here ‘parents in the model’ and ‘descendants in the guide’ are with respect to the corresponding (statistical) dependency structure. For example, this condition is always satisfied if the model and guide have identical dependency structures.
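
For instance, in the Beta-Bernoulli example at the top of this page the model and guide share the same dependency structure and the Beta/Beta KL divergence is available in closed form, so this estimator can be used as a drop-in replacement for Trace_ELBO:

    svi = SVI(model, guide, optimizer, loss=TraceMeanField_ELBO())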

loss_with_mutable_state(rng_key, param_map, model, guide, *args, **kwargs)[source]

Like loss() but also update and return the mutable state, which stores the values at mutable() sites.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • param_map (dict) – dictionary of current parameter values keyed by site name.

  • model – Python callable with NumPyro primitives for the model.

  • guide – Python callable with NumPyro primitives for the guide.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).

Returns:

a dictionary containing the ELBO loss and the mutable state

RenyiELBO

class RenyiELBO(alpha=0, num_particles=2)[source]

Bases: ELBO

An implementation of Renyi’s \(\alpha\)-divergence variational inference following reference [1]. In order for the objective to be a strict lower bound, we require \(\alpha \ge 0\). Note, however, that according to reference [1], depending on the dataset \(\alpha < 0\) might give better results. In the special case \(\alpha = 0\), the objective function is that of the importance weighted autoencoder derived in reference [2].

Note

Setting \(\alpha < 1\) gives a better bound than the usual ELBO.

Parameters:
  • alpha (float) – The order of \(\alpha\)-divergence. Here \(\alpha \neq 1\). Default is 0.

  • num_particles – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.

  • vectorize_particles – Whether to use jax.vmap to compute ELBOs over the num_particles-many particles in parallel. If False use jax.lax.map. Defaults to True.

Example:

import jax.numpy as jnp
import optax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import RenyiELBO, SVI

def model(data):
    with numpyro.plate("batch", 10000, subsample_size=100):
        latent = numpyro.sample("latent", dist.Normal(0, 1))
        batch = numpyro.subsample(data, event_dim=0)
        numpyro.sample("data", dist.Bernoulli(logits=latent), obs=batch)

def guide(data):
    w_loc = numpyro.param("w_loc", 1.)
    w_scale = numpyro.param("w_scale", 1.)
    with numpyro.plate("batch", 10000, subsample_size=100):
        batch = numpyro.subsample(data, event_dim=0)
        loc = w_loc * batch
        scale = jnp.exp(w_scale * batch)
        numpyro.sample("latent", dist.Normal(loc, scale))

elbo = RenyiELBO(num_particles=10)
svi = SVI(model, guide, optax.adam(0.1), elbo)

References:

  1. Renyi Divergence Variational Inference, Yingzhen Li, Richard E. Turner

  2. Importance Weighted Autoencoders, Yuri Burda, Roger Grosse, Ruslan Salakhutdinov

loss(rng_key, param_map, model, guide, *args, **kwargs)[source]

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • param_map (dict) – dictionary of current parameter values keyed by site name.

  • model – Python callable with NumPyro primitives for the model.

  • guide – Python callable with NumPyro primitives for the guide.

  • args – arguments to the model / guide (these can possibly vary during the course of fitting).

  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).

Returns:

negative of the Evidence Lower Bound (ELBO) to be minimized.