Inference Utilities

predictive

predictive(rng, model, posterior_samples, return_sites=None, *args, **kwargs)

Runs model, substituting latent parameters drawn from posterior_samples, and returns the values at sample sites from the forward run. By default, only sites not contained in posterior_samples are returned; this can be changed with the return_sites keyword argument.

Warning

The interface for the predictive function is experimental, and might change in the future.

Parameters:
  • rng (jax.random.PRNGKey) – seed to draw samples
  • model – Python callable containing NumPyro primitives.
  • posterior_samples (dict) – dictionary of samples from the posterior.
  • return_sites (list) – sites to return; by default only sample sites not present in posterior_samples are returned.
  • args – model arguments.
  • kwargs – model kwargs.
Returns:

dict of samples from the predictive distribution.
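
A minimal sketch of the behaviour described above, written in plain Python rather than NumPyro. It assumes a toy convention that is not NumPyro's effect-handler machinery: a model receives a sample(name, draw) callback, where draw takes a random.Random instance and returns a value. The hypothetical toy_predictive runs the model once per posterior draw. It substitutes the stored value at every site found in posterior_samples, samples all other sites forward, and by default returns only those other sites.

```python
import random


def toy_predictive(seed, model, posterior_samples, return_sites=None, *args, **kwargs):
    """Hypothetical stand-in for `predictive`: one forward run per posterior draw."""
    rng = random.Random(seed)
    num_draws = len(next(iter(posterior_samples.values())))
    outputs = {}
    for i in range(num_draws):
        trace = {}

        def sample(name, draw):
            # Substitute the i-th posterior value at latent sites; sample the rest forward.
            value = posterior_samples[name][i] if name in posterior_samples else draw(rng)
            trace[name] = value
            return value

        model(sample, *args, **kwargs)
        # By default, keep only the sites that are not already in posterior_samples.
        if return_sites is None:
            sites = [name for name in trace if name not in posterior_samples]
        else:
            sites = return_sites
        for name in sites:
            outputs.setdefault(name, []).append(trace[name])
    return outputs


def model(sample, noise_scale):
    # Hypothetical toy model: latent mu, observed-style site y.
    mu = sample("mu", lambda rng: rng.gauss(0.0, 1.0))
    return sample("y", lambda rng: rng.gauss(mu, noise_scale))


preds = toy_predictive(0, model, {"mu": [0.5, 1.0, 1.5]}, None, 1e-9)
print(sorted(preds))                       # ['y']
print([round(v, 3) for v in preds["y"]])   # [0.5, 1.0, 1.5]
```

As in the documented signature, return_sites sits before *args, so positional model arguments have to follow an explicit return_sites value (None above).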

log_density

log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False)

Computes the log of the joint density of model, given latent values params.

Parameters:
  • model – Python callable containing NumPyro primitives.
  • model_args (tuple) – args provided to the model.
  • model_kwargs (dict) – kwargs provided to the model.
  • params (dict) – dictionary of current parameter values keyed by site name.
  • skip_dist_transforms (bool) – whether to compute log probability of a site (if its prior is a transformed distribution) in its base distribution domain.
Returns:

tuple of (log of the joint density, corresponding model trace).
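
To make "log of the joint density" concrete, the sketch below scores a hypothetical model with one latent site, mu ~ Normal(0, 1), and observations x_i ~ Normal(mu, 1). The helpers normal_log_prob and toy_log_density are assumptions for illustration; they only mirror the documented return shape (a summed log density plus a per-site trace) and do not model skip_dist_transforms.

```python
import math


def normal_log_prob(x, loc, scale):
    # log N(x | loc, scale)
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def toy_log_density(data, params):
    """Hypothetical: return (log joint density, trace) for mu ~ N(0, 1), x_i ~ N(mu, 1)."""
    mu = params["mu"]
    trace = {
        "mu": {"value": mu, "log_prob": normal_log_prob(mu, 0.0, 1.0)},
        "obs": {"value": data, "log_prob": sum(normal_log_prob(x, mu, 1.0) for x in data)},
    }
    # The joint log density sums the log probabilities of all sites,
    # latent and observed alike.
    log_joint = sum(site["log_prob"] for site in trace.values())
    return log_joint, trace


log_joint, trace = toy_log_density([0.0], {"mu": 0.0})
print(round(log_joint, 6))  # -1.837877
```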

transform_fn

transform_fn(transforms, params, invert=False)

Applies each transform in the transforms dict to the value with the same name in the params dict, and returns the transformed values keyed on the same names.

Parameters:
  • transforms – Dictionary of transforms keyed by names. Names in transforms and params should align.
  • params – Dictionary of arrays keyed by names.
  • invert – Whether to apply the inverse of the transforms.
Returns:

dict of transformed params.
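
A minimal sketch of the name-aligned mapping, assuming transforms are plain objects that are callable in the forward direction and expose an inv method. ExpTransform, AffineTransform and toy_transform_fn are hypothetical stand-ins, not NumPyro classes.

```python
import math


class ExpTransform:
    """Hypothetical transform: maps the reals to the positive reals."""

    def __call__(self, x):
        return math.exp(x)

    def inv(self, y):
        return math.log(y)


class AffineTransform:
    """Hypothetical transform: y = loc + scale * x."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def __call__(self, x):
        return self.loc + self.scale * x

    def inv(self, y):
        return (y - self.loc) / self.scale


def toy_transform_fn(transforms, params, invert=False):
    # Apply the transform (or its inverse) that shares each value's name.
    if invert:
        return {name: transforms[name].inv(value) for name, value in params.items()}
    return {name: transforms[name](value) for name, value in params.items()}


transforms = {"sigma": ExpTransform(), "mu": AffineTransform(1.0, 2.0)}
constrained = toy_transform_fn(transforms, {"sigma": 0.0, "mu": 0.5})
print(constrained)                                             # {'sigma': 1.0, 'mu': 2.0}
print(toy_transform_fn(transforms, constrained, invert=True))  # {'sigma': 0.0, 'mu': 0.5}
```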

constrain_fn

constrain_fn(model, model_args, model_kwargs, transforms, params)

Gets the value at each latent site in model given unconstrained parameters params. The transforms dict is used to map these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of the base distribution; otherwise, it lies in the support of the distribution itself.

Parameters:
  • model – a callable containing NumPyro primitives.
  • model_args (tuple) – args provided to the model.
  • model_kwargs (dict) – kwargs provided to the model.
  • transforms (dict) – dictionary of transforms keyed by names. Names in transforms and params should align.
  • params (dict) – dictionary of unconstrained values keyed by site names.
Returns:

dict of values at the latent sites, keyed by site name.
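
The distinction between a site's base value and its final value can be sketched with two hypothetical sites: scale, with a half-normal prior (support on the positive reals, not a transformed distribution), and x, with a log-normal prior treated as a Normal base distribution pushed through exp. Under these assumptions, transforms maps scale through exp (to reach its support) and x through the identity (its base support is already the reals); the prior's own exp is then applied to x. All names and the dict-of-functions representation are assumptions for illustration.

```python
import math

# Hypothetical transforms from the unconstrained space to each site's *base* support.
transforms = {"scale": math.exp, "x": lambda u: u}

# Hypothetical bijections applied *inside* transformed-distribution priors
# (LogNormal treated as exp applied to a Normal base distribution).
dist_transforms = {"x": math.exp}


def toy_constrain_fn(transforms, params):
    values = {}
    for name, u in params.items():
        base_value = transforms[name](u)  # unconstrained -> base support
        push = dist_transforms.get(name)
        # A transformed distribution maps its base value into the site's support.
        values[name] = push(base_value) if push is not None else base_value
    return values


print(toy_constrain_fn(transforms, {"scale": 0.0, "x": 0.0}))  # {'scale': 1.0, 'x': 1.0}
```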

potential_energy

potential_energy(model, model_args, model_kwargs, inv_transforms, params)

Computes the potential energy of model at unconstrained parameters params. The inv_transforms dict is used to map these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of the base distribution; otherwise, it lies in the support of the distribution itself. Fixing the first four arguments (for example with functools.partial) yields a callable of params alone.

Parameters:
  • model – a callable containing NumPyro primitives.
  • model_args (tuple) – args provided to the model.
  • model_kwargs (dict) – kwargs provided to the model.
  • inv_transforms (dict) – dictionary of transforms keyed by names.
  • params (dict) – dictionary of unconstrained values keyed by site names.
Returns:

potential energy of the model at the given unconstrained parameters.
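
For an HMC-style sampler working in the unconstrained space, potential energy is usually defined as the negative log joint density at the constrained values, corrected by the log absolute Jacobian determinant of the transform: U(u) = -(log p(T(u)) + log|det dT/du|). The sketch below assumes a single hypothetical site sigma with an Exponential(1) prior and T = exp, for which log|dT/du| = u, so U(u) = exp(u) - u. It then uses functools.partial to obtain a callable of the unconstrained parameters only. The helper names are assumptions for illustration.

```python
import functools
import math


def exponential_log_prob(value, rate=1.0):
    # log density of Exponential(rate) at value > 0
    return math.log(rate) - rate * value


def toy_potential_energy(inv_transforms, log_abs_det_jacobians, params):
    """Hypothetical: negative log joint density in the unconstrained space."""
    constrained = {name: inv_transforms[name](u) for name, u in params.items()}
    log_joint = exponential_log_prob(constrained["sigma"])
    # Change-of-variables correction for each transformed site.
    for name, u in params.items():
        log_joint += log_abs_det_jacobians[name](u)
    return -log_joint


# sigma = exp(u), so log|d sigma / du| = u.
inv_transforms = {"sigma": math.exp}
log_abs_det_jacobians = {"sigma": lambda u: u}

# Fix everything except the unconstrained parameters to get a callable.
potential_fn = functools.partial(toy_potential_energy, inv_transforms, log_abs_det_jacobians)
print(potential_fn({"sigma": 0.0}))  # 1.0
```

Without the Jacobian term, the sampler would target the wrong density in u; for this prior the corrected energy is minimized at u = 0.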

init_to_median

init_to_median(site, num_samples=15, skip_param=False)

Initialize to the prior median, estimated from num_samples draws from the prior.
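
A sketch of median initialization for a scalar site, assuming the median is estimated from num_samples prior draws. The sampler interface and the toy_init_to_median name are hypothetical. With an odd num_samples (15 by default) the median is itself one of the draws, so it lies in the prior's support, and it is less sensitive to an occasional extreme draw than a single prior sample.

```python
import random
import statistics


def toy_init_to_median(draw, num_samples=15, seed=0):
    """Hypothetical: median of `num_samples` draws from a site's prior sampler."""
    rng = random.Random(seed)
    samples = [draw(rng) for _ in range(num_samples)]
    return statistics.median(samples)


# Prior: Uniform(2, 4); its true median is 3, and 15 draws give a noisy estimate.
estimate = toy_init_to_median(lambda rng: rng.uniform(2.0, 4.0))
print(2.0 <= estimate <= 4.0)  # True
```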

init_to_prior

init_to_prior(site, skip_param=False)

Initialize to a prior sample.

init_to_uniform

init_to_uniform(site, radius=2, skip_param=False)

Initialize to a random point in the area (-radius, radius) of the unconstrained domain.
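
A sketch of the uniform strategy for one positive site, assuming its bijection from the unconstrained reals is exp. The function name and the to_constrained argument are hypothetical. Drawing in the unconstrained space and mapping back guarantees the starting value is inside the support, whatever the distribution's parameters are.

```python
import math
import random


def toy_init_to_uniform(to_constrained, radius=2, seed=0):
    """Hypothetical: uniform draw in (-radius, radius) of the unconstrained space."""
    rng = random.Random(seed)
    unconstrained = rng.uniform(-radius, radius)
    # Map back to the site's support so the starting value is always feasible.
    return unconstrained, to_constrained(unconstrained)


u, sigma = toy_init_to_uniform(math.exp)
print(-2 <= u <= 2, sigma > 0)  # True True
```

With radius=0 the draw collapses to the origin of the unconstrained space, so this site would start at exp(0) = 1.0.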

init_to_feasible

init_to_feasible(site, skip_param=False)

Initialize to an arbitrary feasible point, ignoring distribution parameters.

find_valid_initial_params

find_valid_initial_params(rng, model, *model_args, init_strategy=<function init_to_uniform>, param_as_improper=False, prototype_params=None, **model_kwargs)

Given a model with NumPyro primitives, returns a set of initial unconstrained parameters, together with an is_valid flag indicating whether those initial parameters are valid.

Parameters:
  • rng (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng.shape[:-1].
  • model – Python callable containing NumPyro primitives.
  • *model_args – args provided to the model.
  • init_strategy (callable) – a per-site initialization function.
  • param_as_improper (bool) – a flag to decide whether to consider sites with param statement as sites with improper priors.
  • **model_kwargs – kwargs provided to the model.
Returns:

tuple of (init_params, is_valid).
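
The is_valid flag suggests a retry loop: draw candidate parameters with the init strategy, check that the potential energy is finite, and stop at the first valid candidate. The sketch below assumes that strategy with a fixed maximum number of trials; the trial limit, the finiteness check and all names are assumptions, not a description of NumPyro's implementation. It also ignores the batch dimension given by rng.shape[:-1].

```python
import math
import random


def toy_find_valid_initial_params(potential_fn, init_strategy, seed=0, max_tries=100):
    """Hypothetical: retry an init strategy until the potential energy is finite."""
    rng = random.Random(seed)
    params = None
    for _ in range(max_tries):
        params = init_strategy(rng)
        try:
            energy = potential_fn(params)
        except (ValueError, OverflowError):
            continue  # e.g. log of a non-positive number
        if math.isfinite(energy):
            return params, True
    # No valid point found: return the last candidate, flagged as invalid.
    return params, False


def potential_fn(params):
    # Toy potential that is only defined for x > 0: U(x) = x - log(x).
    x = params["x"]
    return x - math.log(x)


def init_strategy(rng):
    return {"x": rng.uniform(-2, 2)}


params, is_valid = toy_find_valid_initial_params(potential_fn, init_strategy)
print(is_valid, params["x"] > 0)  # True True
```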