Stochastic Variational Inference (SVI)

class SVI(model, guide, loss, optim, **kwargs)

Bases: object

Stochastic Variational Inference given an ELBo loss objective.

Parameters:
  • model – Python callable with Pyro primitives for the model.
  • guide – Python callable with Pyro primitives for the guide (recognition network).
  • loss – ELBo loss, i.e. negative Evidence Lower Bound, to minimize.
  • optim – an instance of _NumpyroOptim.
  • **kwargs – static arguments for the model / guide, i.e. arguments that remain constant during fitting.
Returns:

an SVI instance; fitting is driven through its init, update, and evaluate methods, documented below.

init(rng, model_args=(), guide_args=())
Parameters:
  • rng (jax.random.PRNGKey) – random number generator seed.
  • model_args (tuple) – arguments to the model (these can possibly vary during the course of fitting).
  • guide_args (tuple) – arguments to the guide (these can possibly vary during the course of fitting).
Returns:

a tuple containing the initial SVIState and get_params, a callable that transforms unconstrained parameter values from the optimizer to their specified constrained domains.
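The optimizer works on unconstrained real values, so a gradient step can never push, say, a scale parameter below zero; get_params maps those values back to each site's constrained domain. A minimal sketch of that mapping, assuming a hypothetical per-site transform table (CONSTRAINT_TRANSFORMS and make_get_params are illustrative names, not NumPyro API):

```python
import jax
import jax.numpy as jnp

# Hypothetical table: param-site name -> transform from an unconstrained
# real value to the site's constrained domain (illustrative assumption only).
CONSTRAINT_TRANSFORMS = {
    "loc": lambda u: u,         # real line: identity
    "scale": jnp.exp,           # positive reals: exp
    "prob": jax.nn.sigmoid,     # unit interval: logistic sigmoid
}


def make_get_params(transforms):
    """Return a get_params-style callable (hypothetical sketch)."""
    def get_params(unconstrained):
        # Map every unconstrained optimizer value to its constrained domain.
        return {name: transforms[name](value) for name, value in unconstrained.items()}
    return get_params


get_params = make_get_params(CONSTRAINT_TRANSFORMS)
unconstrained = {"loc": jnp.array(0.5), "scale": jnp.array(-1.0), "prob": jnp.array(0.0)}
params = get_params(unconstrained)
```

Keeping the stored values unconstrained lets any plain gradient optimizer run without projection steps; the constraint lives entirely in the transform applied on the way out.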

get_params(svi_state)

Gets values at param sites of the model and guide.

Parameters:
  • svi_state – current state of the optimizer.

update(svi_state, model_args=(), guide_args=())

Take a single step of SVI (possibly on a batch / minibatch of data), using the optimizer.

Parameters:
  • svi_state – current state of SVI.
  • model_args (tuple) – dynamic arguments to the model.
  • guide_args (tuple) – dynamic arguments to the guide.
Returns:

tuple of (svi_state, loss).
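Conceptually, one update call draws fresh randomness, differentiates the loss with respect to the parameters, and applies an optimizer step. The sketch below shows that shape in plain JAX under stated assumptions: ToySVIState, toy_loss and make_update are hypothetical names, toy_loss is a stand-in objective rather than an ELBo, and plain SGD replaces the _NumpyroOptim optimizer.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySVIState(NamedTuple):
    """Hypothetical stand-in for SVIState: parameters plus a PRNG key."""
    params: dict
    rng: jnp.ndarray


def toy_loss(params, rng, target):
    # Stand-in objective (NOT an ELBo): mean squared error of reparameterized
    # Gaussian draws z = loc + exp(log_scale) * eps.
    eps = jax.random.normal(rng, (16,))
    z = params["loc"] + jnp.exp(params["log_scale"]) * eps
    return jnp.mean((z - target) ** 2)


def make_update(loss_fn, step_size):
    """Build update(svi_state, *args) -> (svi_state, loss)."""
    loss_and_grad = jax.value_and_grad(loss_fn)  # gradient w.r.t. params only

    @jax.jit
    def update(svi_state, *args):
        # Split the key so every step sees fresh Monte Carlo noise.
        rng, step_rng = jax.random.split(svi_state.rng)
        loss, grads = loss_and_grad(svi_state.params, step_rng, *args)
        # Plain SGD step; a real optimizer would also carry its own state.
        params = jax.tree_util.tree_map(lambda p, g: p - step_size * g,
                                        svi_state.params, grads)
        return ToySVIState(params, rng), loss

    return update


update = make_update(toy_loss, step_size=0.1)
state = ToySVIState({"loc": jnp.array(0.0), "log_scale": jnp.array(0.0)},
                    jax.random.PRNGKey(0))
for _ in range(200):
    state, loss = update(state, 3.0)  # 3.0 plays the role of a dynamic model arg
```

Threading the key through the state keeps each step a pure function of its inputs, which is what allows it to be jit-compiled while still drawing new randomness on every call.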

evaluate(svi_state, model_args=(), guide_args=())

Evaluate the ELBo loss at the current parameter values (possibly on a batch / minibatch of data) without taking an optimization step.

Parameters:
  • svi_state – current state of SVI.
  • model_args (tuple) – arguments to the model (these can possibly vary during the course of fitting).
  • guide_args (tuple) – arguments to the guide (these can possibly vary during the course of fitting).
Returns:

the ELBo loss evaluated at the current parameter values (held within svi_state.optim_state).
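Unlike update, evaluate only reports the loss: it neither differentiates nor changes the parameters. A minimal sketch under the same kind of assumptions (ToySVIState, toy_loss and evaluate here are hypothetical, and toy_loss is a stand-in for the ELBo loss, not the library's implementation):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySVIState(NamedTuple):
    """Hypothetical stand-in for SVIState."""
    params: dict
    rng: jnp.ndarray


def toy_loss(params, rng, target):
    # Stand-in objective (not an ELBo) built from reparameterized draws.
    eps = jax.random.normal(rng, (16,))
    z = params["loc"] + jnp.exp(params["log_scale"]) * eps
    return jnp.mean((z - target) ** 2)


def evaluate(svi_state, *args):
    # Score the current parameters: no gradient, no optimizer step, and the
    # state passed in is left untouched.
    _, eval_rng = jax.random.split(svi_state.rng)
    return toy_loss(svi_state.params, eval_rng, *args)


state = ToySVIState({"loc": jnp.array(3.0), "log_scale": jnp.array(-5.0)},
                    jax.random.PRNGKey(0))
loss = evaluate(state, 3.0)
```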

ELBo

elbo(rng, param_map, model, guide, model_args, guide_args, kwargs)

This is the most basic implementation of the Evidence Lower Bound, the fundamental objective in Variational Inference. It has various limitations (for example, it only supports random variables with reparameterized samplers), but it can serve as a template for building more sophisticated loss objectives.

For more details, refer to http://pyro.ai/examples/svi_part_i.html.
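For reference, the objective in standard notation, where x is the observed data, z the latent variables and q_φ the guide (these symbols are introduced here for illustration, not taken from the API). With a reparameterized sampler z = g_φ(ε), ε ~ p(ε), a single draw gives an unbiased estimate whose gradient flows through g_φ; samplers without such a g_φ would need a different gradient estimator, which is the limitation noted above.

```latex
\mathrm{ELBO}(\phi)
  = \mathbb{E}_{q_\phi(z)}\big[\log p(x, z) - \log q_\phi(z)\big]
  \le \log p(x)

\widehat{\mathrm{ELBO}}(\phi)
  = \log p\big(x, g_\phi(\epsilon)\big) - \log q_\phi\big(g_\phi(\epsilon)\big),
  \qquad \epsilon \sim p(\epsilon)

\mathcal{L}(\phi) = -\,\widehat{\mathrm{ELBO}}(\phi) \quad \text{(the value returned, to be minimized)}
```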

Parameters:
  • rng (jax.random.PRNGKey) – random number generator seed.
  • param_map (dict) – dictionary of current parameter values keyed by site name.
  • model – Python callable with Pyro primitives for the model.
  • guide – Python callable with Pyro primitives for the guide (recognition network).
  • model_args (tuple) – arguments to the model (these can possibly vary during the course of fitting).
  • guide_args (tuple) – arguments to the guide (these can possibly vary during the course of fitting).
  • kwargs (dict) – static keyword arguments to the model / guide.
Returns:

negative of the Evidence Lower Bound (ELBo) to be minimized.
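A minimal single-sample version of this computation in JAX, assuming a hypothetical conjugate model z ~ Normal(0, 1), x ~ Normal(z, 1) and a Normal(loc, scale) guide. neg_elbo and its argument order are illustrative and do not mirror the elbo signature above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def neg_elbo(rng, params, x):
    """Single-sample reparameterized estimate of -ELBO (illustrative sketch).

    Hypothetical model: z ~ Normal(0, 1), x ~ Normal(z, 1).
    Hypothetical guide: z ~ Normal(loc, scale), params = {"loc", "scale"}.
    """
    loc, scale = params["loc"], params["scale"]
    # Reparameterized draw: z is a differentiable function of (loc, scale).
    eps = jax.random.normal(rng)
    z = loc + scale * eps
    log_joint = norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x, z, 1.0)
    log_guide = norm.logpdf(z, loc, scale)
    return -(log_joint - log_guide)


x = 1.5
exact = {"loc": jnp.array(x / 2), "scale": jnp.sqrt(0.5)}  # exact posterior
value = neg_elbo(jax.random.PRNGKey(0), exact, x)
log_evidence = norm.logpdf(x, 0.0, jnp.sqrt(2.0))  # log p(x) under this model
# Gradients flow through the sample thanks to the reparameterization.
grads = jax.grad(neg_elbo, argnums=1)(jax.random.PRNGKey(1), exact, x)
```

When the guide equals the exact posterior Normal(x/2, sqrt(0.5)), log p(x, z) - log q(z) = log p(x) for every z, so the estimate has zero variance and equals -log p(x); any other guide gives a larger expected value.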