Utilities
enable_validation

enable_validation(is_validate=True)
Enable or disable validation checks in NumPyro. Validation checks provide useful warnings and errors, e.g. NaN checks and validation of distribution arguments and support values, which are useful for debugging.
Note
This utility does not take effect under JAX’s JIT compilation or the vectorized transformation jax.vmap().
Parameters: is_validate (bool) – whether to enable validation checks.
validation_enabled
set_platform
set_host_device_count

set_host_device_count(n)
By default, XLA considers all CPU cores as one device. This utility tells XLA that there are n host (CPU) devices available to use. As a consequence, it allows JAX’s parallel map, jax.pmap(), to work on the CPU platform.
Note
This utility only takes effect if called at the beginning of your program. Under the hood, it sets the environment variable XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices], where [num_devices] is the desired number of CPU devices n.
Warning
We do not fully understand the side effects of using the xla_force_host_platform_device_count flag. If you observe any strange behavior when using this utility, please let us know through our issue tracker or forum. Here we quote the meaning of this flag from the XLA source code: “Force the host platform to pretend that there are these many host ‘devices’. All of these host devices are backed by the same threadpool. Setting this to anything other than 1 can increase overhead from context switching but we let the user override this behavior to help run tests on the host that run models in parallel across multiple devices.”
Parameters: n (int) – number of CPU devices to use.
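The mechanism in the note above can be sketched in plain Python. The helper `force_host_device_count` is hypothetical, not part of NumPyro: it only rewrites `XLA_FLAGS` in `os.environ` (keeping unrelated flags), so, like the real utility, it has to run before JAX initializes its backends.

```python
import os
import re

_FLAG = "--xla_force_host_platform_device_count"


def force_host_device_count(n: int) -> str:
    """Hypothetical helper: ask XLA to expose `n` host (CPU) devices.

    Any existing device-count flag in XLA_FLAGS is replaced; other flags are kept.
    It only has an effect if called before JAX initializes its backends.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    existing = os.environ.get("XLA_FLAGS", "")
    # Drop any previous occurrence of the device-count flag, keep everything else.
    kept = re.sub(_FLAG + r"=\S*", "", existing).split()
    os.environ["XLA_FLAGS"] = " ".join(kept + [f"{_FLAG}={n}"])
    return os.environ["XLA_FLAGS"]
```

For instance, calling `force_host_device_count(4)` and only then importing `jax` should make `jax.local_device_count()` report 4 on a CPU-only setup.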
Inference Utilities
predictive

predictive(rng, model, posterior_samples, *args, num_samples=None, return_sites=None, **kwargs)
Run model by sampling latent parameters from posterior_samples, and return values at sample sites from the forward run. By default, only sample sites not contained in posterior_samples are returned. This can be modified by changing the return_sites keyword argument.
Warning
The interface for the predictive function is experimental, and might change in the future.
Parameters:  rng (jax.random.PRNGKey) – seed to draw samples
 model – Python callable containing Pyro primitives.
 posterior_samples (dict) – dictionary of samples from the posterior.
 args – model arguments.
 return_sites (list) – sites to return; by default only sample sites not present in posterior_samples are returned.
 num_samples (int) – number of samples
 kwargs – model kwargs.
Returns: dict of samples from the predictive distribution.
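To make the description concrete, here is a toy pure-Python sketch of the idea, not NumPyro's implementation. The "model" is a hypothetical function that receives already-chosen latent values and draws any remaining sites itself, standing in for NumPyro's effect handlers. Each posterior draw of the latents `a` and `b` is plugged in and `y` is simulated forward; by default only sites absent from `posterior_samples` are returned. The toy omits `num_samples` and simply uses every posterior draw.

```python
import random
from typing import Callable, Dict, List, Optional


def toy_model(rng: random.Random, latents: Dict[str, float], xs: List[float]) -> Dict[str, object]:
    """Toy regression: a, b ~ Normal(0, 1); y_i ~ Normal(a + b * x_i, 1).

    Latent sites missing from `latents` are drawn from their priors.
    Returns the value of every sample site.
    """
    a = latents["a"] if "a" in latents else rng.gauss(0.0, 1.0)
    b = latents["b"] if "b" in latents else rng.gauss(0.0, 1.0)
    y = [rng.gauss(a + b * x, 1.0) for x in xs]
    return {"a": a, "b": b, "y": y}


def toy_predictive(seed: int, model: Callable, posterior_samples: Dict[str, List[float]],
                   *args, return_sites: Optional[List[str]] = None) -> Dict[str, list]:
    rng = random.Random(seed)
    num_draws = len(next(iter(posterior_samples.values())))
    out: Dict[str, list] = {}
    for i in range(num_draws):
        # Plug in the i-th posterior draw of every latent site, then run forward.
        latents = {name: draws[i] for name, draws in posterior_samples.items()}
        trace = model(rng, latents, *args)
        if return_sites is None:
            sites = [name for name in trace if name not in posterior_samples]
        else:
            sites = return_sites
        for name in sites:
            out.setdefault(name, []).append(trace[name])
    return out
```

Usage: `toy_predictive(0, toy_model, {"a": [0.0, 1.0], "b": [1.0, 1.0]}, [0.0, 1.0])` returns a dict with only the key `"y"`, holding one simulated vector per posterior draw.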
log_density

log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False)
Computes the log of the joint density for the model given latent values params.
Parameters:  model – Python callable containing NumPyro primitives.
 model_args (tuple) – args provided to the model.
 model_kwargs (dict) – kwargs provided to the model.
 params (dict) – dictionary of current parameter values keyed by site name.
 skip_dist_transforms (bool) – whether to compute log probability of a site (if its prior is a transformed distribution) in its base distribution domain.
Returns: log of joint density and a corresponding model trace
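As a worked illustration (a hand-written toy, not NumPyro code), take a model with prior mu ~ Normal(0, 1) and observations y_i ~ Normal(mu, 1). The log joint density at `params` is the log prior of the latent site plus the summed log likelihood of the observed site. The helper names are hypothetical, and unlike `log_density` the toy returns only the scalar, not a model trace.

```python
import math
from typing import Dict, List


def normal_logpdf(x: float, loc: float, scale: float) -> float:
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def toy_log_density(params: Dict[str, float], ys: List[float]) -> float:
    mu = params["mu"]
    log_prior = normal_logpdf(mu, 0.0, 1.0)                # latent site "mu"
    log_lik = sum(normal_logpdf(y, mu, 1.0) for y in ys)   # observed site "y"
    return log_prior + log_lik
```

At mu = 0 with a single observation y = 0, both terms equal -0.5 log(2π), so the log joint density is -log(2π).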
transform_fn

transform_fn(transforms, params, invert=False)
Applies each transform in the transforms dict to the value with the same name in the params dict, and returns the transformed values keyed on the same names.
Parameters:  transforms – Dictionary of transforms keyed by names. Names in transforms and params should align.
 params – Dictionary of arrays keyed by names.
 invert – Whether to apply the inverse of the transforms.
Returns: dict of transformed params.
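A minimal pure-Python sketch of the same contract, assuming each transform is represented as a hypothetical `(forward, inverse)` pair of plain functions rather than a NumPyro transform object:

```python
import math
from typing import Callable, Dict, Tuple

# Hypothetical representation: (forward, inverse) pair of scalar functions.
Transform = Tuple[Callable[[float], float], Callable[[float], float]]


def toy_transform_fn(transforms: Dict[str, Transform], params: Dict[str, float],
                     invert: bool = False) -> Dict[str, float]:
    """Apply transforms[k] (or its inverse) to params[k]; names must align."""
    index = 1 if invert else 0
    return {name: transforms[name][index](value) for name, value in params.items()}


# Maps the real line onto (0, inf), e.g. for a scale parameter.
exp_transform: Transform = (math.exp, math.log)
```

Usage: `toy_transform_fn({"sigma": exp_transform}, {"sigma": 0.0})` gives `{"sigma": 1.0}`, and passing `invert=True` maps it back to `{"sigma": 0.0}`.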
constrain_fn

constrain_fn(model, model_args, model_kwargs, transforms, params)
Gets the value at each latent site in model given unconstrained parameters params. The transforms dict is used to transform these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of the base distribution; otherwise, it lies in the support of the distribution itself.
Parameters:  model – a callable containing NumPyro primitives.
 model_args (tuple) – args provided to the model.
 model_kwargs (dict) – kwargs provided to the model.
 transforms (dict) – dictionary of transforms keyed by names. Names in transforms and params should align.
 params (dict) – dictionary of unconstrained values keyed by site names.
Returns: dict of transformed params.
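The two-step reading of this description can be sketched as a pure-Python toy with hypothetical names (not NumPyro code): the transforms first map each unconstrained value to a base value, and a site whose prior is a transformed distribution then applies that prior's own transform to get the site value. Here site "scale" has a LogNormal prior (base Normal(0, 1) plus an exp transform) and site "p" has a Beta prior on (0, 1).

```python
import math
from typing import Callable, Dict

Fn = Callable[[float], float]


def sigmoid(u: float) -> float:
    return 1.0 / (1.0 + math.exp(-u))


# Hypothetical transforms from unconstrained values to *base* values of each prior.
to_base: Dict[str, Fn] = {
    "scale": lambda u: u,  # LogNormal prior: its base Normal(0, 1) lives on the real line
    "p": sigmoid,          # Beta prior (not transformed): its support is (0, 1)
}

# The priors' own transforms; only the LogNormal site has one (base -> exp(base)).
prior_transforms: Dict[str, Fn] = {"scale": math.exp}


def toy_constrain_fn(to_base: Dict[str, Fn], prior_transforms: Dict[str, Fn],
                     params: Dict[str, float]) -> Dict[str, float]:
    """Map unconstrained params to the value of each latent site."""
    base = {name: to_base[name](u) for name, u in params.items()}
    # Sites without a transformed prior keep their base value unchanged.
    return {name: prior_transforms[name](b) if name in prior_transforms else b
            for name, b in base.items()}
```

With all-zero unconstrained inputs, "scale" becomes exp(0) = 1.0 and "p" becomes sigmoid(0) = 0.5.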
potential_energy

potential_energy(model, model_args, model_kwargs, inv_transforms, params)
Computes the potential energy of a model given unconstrained params. The inv_transforms dict is used to transform these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of the base distribution; otherwise, it lies in the support of the distribution itself.
Parameters:  model – a callable containing NumPyro primitives.
 model_args (tuple) – args provided to the model.
 model_kwargs (dict) – kwargs provided to the model.
 inv_transforms (dict) – dictionary of transforms keyed by names.
 params (dict) – unconstrained parameters of model.
Returns: potential energy given unconstrained parameters.
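Potential energy is the negative log joint density of the model written in unconstrained coordinates, so it includes the log absolute Jacobian determinant of the map from unconstrained to constrained values. The toy below is hand-written pure Python, not NumPyro's code: one latent site sigma ~ Exponential(1), reached via sigma = exp(u), so log|dsigma/du| = u and the potential energy is U(u) = exp(u) - u.

```python
import math


def exponential_logpdf(x: float, rate: float = 1.0) -> float:
    """Log density of Exponential(rate) at x (minus infinity outside the support)."""
    return math.log(rate) - rate * x if x >= 0 else -math.inf


def toy_potential_energy(u: float) -> float:
    """U(u) = -[log p(sigma) + log|d sigma / d u|] with sigma = exp(u)."""
    sigma = math.exp(u)   # unconstrained -> support (0, inf)
    log_jacobian = u      # log|d exp(u) / du| = u
    return -(exponential_logpdf(sigma) + log_jacobian)
```

U(u) is minimized at u = 0, where it equals 1; at u = 1 it equals e - 1.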
log_likelihood

log_likelihood(model, posterior_samples, *args, **kwargs)
Returns the log likelihood at the observation sites of model, given samples of all latent variables.
Warning
The interface for the log_likelihood function is experimental, and might change in the future.
Parameters:  model – Python callable containing Pyro primitives.
 posterior_samples (dict) – dictionary of samples from the posterior.
 args – model arguments.
 kwargs – model kwargs.
Returns: dict of log likelihoods at observation sites.
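A toy pure-Python sketch of the output shape, not NumPyro's implementation: for each posterior draw of a latent mean mu, it evaluates the Normal(mu, 1) log density of every observation, giving one row per draw keyed by the observed site name. Per-draw, per-observation tables like this are the usual input to criteria such as WAIC. The helper names are hypothetical.

```python
import math
from typing import Dict, List


def normal_logpdf(x: float, loc: float, scale: float = 1.0) -> float:
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def toy_log_likelihood(posterior_samples: Dict[str, List[float]],
                       ys: List[float]) -> Dict[str, List[List[float]]]:
    """Log likelihood of each observation y_i ~ Normal(mu, 1), per posterior draw of mu."""
    rows = [[normal_logpdf(y, mu) for y in ys] for mu in posterior_samples["mu"]]
    return {"y": rows}  # keyed by the observed site name
```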
find_valid_initial_params

find_valid_initial_params(rng, model, *model_args, init_strategy=functools.partial(_init_to_uniform, radius=2), param_as_improper=False, prototype_params=None, **model_kwargs)
Given a model with Pyro primitives, returns valid initial unconstrained parameters. This function also returns an is_valid flag indicating whether the initial parameters are valid.
Parameters:  rng (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng.shape[:1].
 model – Python callable containing Pyro primitives.
 *model_args – args provided to the model.
 init_strategy (callable) – a per-site initialization function.
 param_as_improper (bool) – a flag to decide whether to consider sites with param statement as sites with improper priors.
 **model_kwargs – kwargs provided to the model.
Returns: tuple of (init_params, is_valid).
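The default strategy in the signature draws each unconstrained value uniformly from (-radius, radius) with radius=2. The pure-Python toy below mimics that idea together with the validity check: it proposes uniform draws and accepts the first one with a finite potential energy. The retry loop, the `max_tries` cap and the example target (a positive site deliberately left untransformed, so half of the box is invalid) are assumptions for illustration; this section does not say how NumPyro searches.

```python
import math
import random
from typing import Callable, Dict, List, Tuple


def toy_find_valid_initial_params(seed: int, potential: Callable[[Dict[str, float]], float],
                                  site_names: List[str], radius: float = 2.0,
                                  max_tries: int = 100) -> Tuple[Dict[str, float], bool]:
    rng = random.Random(seed)
    params: Dict[str, float] = {}
    for _ in range(max_tries):
        # init_to_uniform-style proposal in unconstrained space.
        params = {name: rng.uniform(-radius, radius) for name in site_names}
        if math.isfinite(potential(params)):
            return params, True
    return params, False  # no valid point found: report is_valid=False


def untransformed_exponential_potential(params: Dict[str, float]) -> float:
    """-log Exponential(1) density of a site that was (wrongly) left untransformed."""
    x = params["sigma"]
    return x if x >= 0 else math.inf
```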
init_to_median
init_to_uniform
init_to_feasible
init_to_value

init_to_value(values)
Initialize to the value specified in values. We defer to the init_to_uniform() strategy for sites that do not appear in values.
Parameters: values (dict) – dictionary of initial values keyed by site name.
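A pure-Python sketch of the fallback behavior described above, with hypothetical names (real NumPyro strategies operate on sample-site messages, which this toy does not model): a site listed in `values` gets that value, and any other site is handed to a fallback, here a stand-in for init_to_uniform that draws uniformly in (-2, 2).

```python
import random
from typing import Callable, Dict

InitFn = Callable[[str], float]


def toy_init_to_value(values: Dict[str, float], fallback: InitFn) -> InitFn:
    """Build a per-site init function: values[site] if given, else fallback(site)."""
    def init(site_name: str) -> float:
        if site_name in values:
            return values[site_name]
        return fallback(site_name)  # defer for sites not listed in `values`
    return init


_rng = random.Random(0)


def uniform_fallback(site_name: str) -> float:
    """Stand-in for init_to_uniform: a uniform draw in (-2, 2)."""
    return _rng.uniform(-2.0, 2.0)


init = toy_init_to_value({"mu": 0.5}, uniform_fallback)
```

Here `init("mu")` returns 0.5, while `init("sigma")` falls back to a uniform draw.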