Automatic Guide Generation

AutoDiagonalNormal

class AutoDiagonalNormal(model, prefix='auto', init_strategy=<function init_to_median>)

Bases: numpyro.contrib.autoguide.AutoContinuous

This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
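As a minimal sketch of what a diagonal-covariance Normal over the flattened latent space means, the class below draws every coordinate independently and sums per-coordinate log-densities. It uses only the standard library. The class name `DiagonalNormalSketch` and its `loc`/`scale` attributes are hypothetical stand-ins for the guide's learnable parameters, not NumPyro's internals.

```python
import math
import random


class DiagonalNormalSketch:
    """Hypothetical diagonal Normal over a flat, unconstrained latent vector."""

    def __init__(self, loc, scale):
        if len(loc) != len(scale):
            raise ValueError("loc and scale must have the same length")
        if any(s <= 0 for s in scale):
            raise ValueError("scale entries must be positive")
        self.loc = list(loc)
        self.scale = list(scale)

    def sample(self, rng):
        # Coordinates are independent: z_i = loc_i + scale_i * eps_i, eps_i ~ N(0, 1).
        return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(self.loc, self.scale)]

    def log_prob(self, z):
        # A diagonal covariance factorizes the density into 1-D Normals,
        # so the joint log-density is a sum over coordinates.
        total = 0.0
        for zi, m, s in zip(z, self.loc, self.scale):
            total += -0.5 * ((zi - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
        return total


guide = DiagonalNormalSketch(loc=[0.0, 1.0, -2.0], scale=[1.0, 0.5, 2.0])
z = guide.sample(random.Random(0))
density_at_loc = guide.log_prob(guide.loc)
```

Because the covariance is diagonal, the guide needs only two parameters per latent coordinate, at the cost of ignoring posterior correlations between coordinates.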

Usage:

guide = AutoDiagonalNormal(model, ...)
svi = SVI(model, guide, ...)

median(params)

Returns the posterior median value of each latent variable.

Parameters: params (dict) – A dict containing parameter values.
Returns: A dict mapping sample site name to median tensor.
Return type: dict
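A rough sketch of how such medians can be obtained, assuming (hypothetically) that the guide keeps a flat unconstrained `loc` vector plus a per-site table of slices and elementwise transforms to each site's support: the median of a Normal is its mean, and a monotone elementwise bijection maps a median to a median, so it is enough to unpack and constrain `loc`. The names `SITE_LAYOUT` and `median_sketch` are illustrative only.

```python
import math

# Hypothetical layout: for each sample site, the slice of the flat unconstrained
# latent vector it occupies and an increasing elementwise map to its support.
SITE_LAYOUT = {
    "mu": (slice(0, 2), lambda u: u),   # real-valued site: identity
    "sigma": (slice(2, 3), math.exp),   # positive site: exp
}


def median_sketch(loc, layout=SITE_LAYOUT):
    """Posterior median per site for a diagonal-Normal guide (illustrative only).

    The median of Normal(loc, scale) is loc, and a monotone elementwise
    bijection maps a median to a median, so only loc needs to be unpacked
    and constrained; scale plays no role.
    """
    return {
        name: [transform(u) for u in loc[idx]]
        for name, (idx, transform) in layout.items()
    }


medians = median_sketch([0.3, -1.2, 0.0])
print(medians)  # {'mu': [0.3, -1.2], 'sigma': [1.0]}
```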
quantiles(params, quantiles)

Returns the posterior quantiles of each latent variable. Example:

print(guide.quantiles(params, [0.05, 0.5, 0.95]))
Parameters:
  • params (dict) – A dict containing parameter values.
  • quantiles (list or jax.numpy.ndarray) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a list of quantile values.
Return type: dict
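Under the same kind of hypothetical layout, each requested quantile q of a Normal(loc, scale) coordinate is loc + scale * inv_cdf(q), where inv_cdf is the standard Normal quantile function; the result is then pushed through the site's increasing transform. The sketch uses `statistics.NormalDist.inv_cdf` from the standard library; `SITE_LAYOUT` and `quantiles_sketch` are assumptions for illustration, not NumPyro's implementation.

```python
import math
from statistics import NormalDist

# Hypothetical layout of a flat unconstrained latent vector into sample sites.
SITE_LAYOUT = {
    "mu": (slice(0, 2), lambda u: u),   # real-valued site: identity
    "sigma": (slice(2, 3), math.exp),   # positive site: exp
}


def quantiles_sketch(loc, scale, quantiles, layout=SITE_LAYOUT):
    """Map each site name to one list of values per requested quantile."""
    if any(not 0.0 < q < 1.0 for q in quantiles):
        raise ValueError("quantiles must lie strictly between 0 and 1")
    result = {}
    for name, (idx, transform) in layout.items():
        per_quantile = []
        for q in quantiles:
            # Quantile of Normal(m, s) in unconstrained space: m + s * inv_cdf(q).
            latent = [NormalDist(m, s).inv_cdf(q) for m, s in zip(loc[idx], scale[idx])]
            # An increasing transform preserves quantiles, so map them to the support.
            per_quantile.append([transform(u) for u in latent])
        result[name] = per_quantile
    return result


qs = quantiles_sketch([0.3, -1.2, 0.0], [1.0, 0.5, 0.2], [0.05, 0.5, 0.95])
```

The 0.5 quantile coincides with the median computed from `loc` alone, which is a handy consistency check.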

AutoIAFNormal

class AutoIAFNormal(model, prefix='auto', init_strategy=<function init_to_median>, num_flows=3, **arn_kwargs)

Bases: numpyro.contrib.autoguide.AutoContinuous

This implementation of AutoContinuous uses a diagonal Normal distribution transformed via an InverseAutoregressiveTransform to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
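A single inverse autoregressive step can be sketched as follows, assuming (hypothetically) linear maps in place of the autoregressive neural network: coordinate i is shifted and scaled using only the coordinates before it, so every output is computed from the input in one pass, and the Jacobian is lower triangular with log-determinant equal to the sum of the log-scales. The function `iaf_step_sketch` and its weight/bias arguments are illustrative and do not mirror the signature of NumPyro's InverseAutoregressiveTransform.

```python
import math


def iaf_step_sketch(x, w_mean, w_log_scale, b_mean, b_log_scale):
    """One affine inverse-autoregressive step (illustrative only).

    y_i = mean_i + exp(log_scale_i) * x_i, where mean_i and log_scale_i are
    hypothetical linear functions of x_0..x_{i-1}. Entries w[i][j] with j >= i
    are never read, which enforces the autoregressive (lower-triangular) mask.
    Returns (y, log_abs_det_jacobian).
    """
    y, log_det = [], 0.0
    for i, xi in enumerate(x):
        mean = b_mean[i] + sum(w_mean[i][j] * x[j] for j in range(i))
        log_scale = b_log_scale[i] + sum(w_log_scale[i][j] * x[j] for j in range(i))
        y.append(mean + math.exp(log_scale) * xi)
        # dy_i/dx_j = 0 for j > i and dy_i/dx_i = exp(log_scale), so the
        # log-determinant of the triangular Jacobian is the sum of log-scales.
        log_det += log_scale
    return y, log_det


x = [0.5, -1.0, 2.0]
w_mean = [[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [-0.2, 0.7, 0.0]]
w_log_scale = [[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [0.4, -0.5, 0.0]]
b_mean, b_log_scale = [0.0, 1.0, -1.0], [0.0, 0.2, -0.3]
y, log_det = iaf_step_sketch(x, w_mean, w_log_scale, b_mean, b_log_scale)
```

Roughly, sampling from the guide draws x from the diagonal Normal base and applies num_flows such steps in sequence; since each step reads only its input, no coordinate has to wait for earlier outputs.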

Usage:

guide = AutoIAFNormal(model, hidden_dims=[20], skip_connections=True, ...)
svi = SVI(model, guide, ...)
Parameters:
  • model (callable) – a generative model.
  • prefix (str) – a prefix prepended to the names of all internal param sites.
  • init_strategy (callable) – A per-site initialization function.
  • num_flows (int) – the number of flows to be used, defaults to 3.
  • **arn_kwargs

    keyword arguments for constructing the autoregressive neural networks, which include:

    • hidden_dims (list[int]) - the dimensionality of the hidden units per layer. Defaults to [latent_size, latent_size].
    • skip_connections (bool) - whether to add skip connections from the input to the output of each flow. Defaults to False.
    • nonlinearity (callable) - the nonlinearity to use in the feedforward network. Defaults to jax.experimental.stax.Relu.
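The hidden_dims and skip_connections options shape the masked network that produces each flow's shift and log-scale. As a sketch, assuming a MADE-style masking scheme (the cyclic degree assignment below is a hypothetical choice, NumPyro's own network may assign degrees differently, and a real network would emit both a shift and a log-scale per dimension), the helper builds one mask per layer from hidden_dims, optionally adds a direct input-to-output skip mask, and checks that output d can only be influenced by inputs before d.

```python
def made_masks_sketch(latent_size, hidden_dims=None, skip_connections=False):
    """Binary connectivity masks for a MADE-style autoregressive network.

    Returns (layer_masks, skip_mask); skip_mask is None unless skip_connections.
    layer_masks[l][k][j] == 1 means unit k of layer l may read unit j of the
    layer below it.
    """
    if latent_size < 2:
        raise ValueError("latent_size must be at least 2")
    if hidden_dims is None:
        # Mirrors the documented default of [latent_size, latent_size].
        hidden_dims = [latent_size, latent_size]
    in_degrees = list(range(1, latent_size + 1))
    prev = in_degrees
    layer_masks = []
    for width in hidden_dims:
        # Hypothetical degree assignment: cycle through 1..latent_size-1.
        degrees = [1 + k % (latent_size - 1) for k in range(width)]
        # A hidden unit of degree k may read units of degree <= k below it.
        layer_masks.append([[int(dk >= dj) for dj in prev] for dk in degrees])
        prev = degrees
    # Output d (degree d) may only read hidden units of strictly smaller degree.
    layer_masks.append([[int(d > dk) for dk in prev] for d in in_degrees])
    skip_mask = None
    if skip_connections:
        # A direct input-to-output path must also be strictly autoregressive.
        skip_mask = [[int(d > dj) for dj in in_degrees] for d in in_degrees]
    return layer_masks, skip_mask


def reachability(layer_masks, skip_mask=None):
    """Which inputs can influence which outputs through the masked layers."""
    n = len(layer_masks[0][0])
    reach = [[int(i == j) for j in range(n)] for i in range(n)]  # identity on inputs
    for mask in layer_masks:
        reach = [
            [int(any(row[m] and reach[m][j] for m in range(len(row)))) for j in range(n)]
            for row in mask
        ]
    if skip_mask is not None:
        reach = [[int(r or s) for r, s in zip(rr, sr)] for rr, sr in zip(reach, skip_mask)]
    return reach


masks, skip = made_masks_sketch(3, skip_connections=True)
print(reachability(masks, skip))  # [[0, 0, 0], [1, 0, 0], [1, 1, 0]]
```

The strictly lower-triangular reachability matrix is exactly the property that keeps each flow's Jacobian triangular, whatever widths hidden_dims specifies.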