# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
# The implementation follows the design in PyTorch: torch.distributions.distribution.py
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
from collections import OrderedDict
from contextlib import contextmanager
import warnings
from jax import lax, tree_util
import jax.numpy as jnp
from numpyro.distributions.constraints import is_dependent, real
from numpyro.distributions.transforms import Transform
from numpyro.distributions.util import lazy_property, promote_shapes, sum_rightmost, validate_sample
from numpyro.util import not_jax_tracer
_VALIDATION_ENABLED = False


def enable_validation(is_validate=True):
    """
    Enable or disable validation checks in NumPyro. Validation checks provide useful warnings and
    errors, e.g. NaN checks, validating distribution arguments and support values, etc. which is
    useful for debugging.

    .. note:: This utility does not take effect under JAX's JIT compilation or vectorized
        transformation :func:`jax.vmap`.

    :param bool is_validate: whether to enable validation checks.
    :raises ValueError: if `is_validate` is rejected by
        :meth:`Distribution.set_default_validate_args`.
    """
    global _VALIDATION_ENABLED
    # Update the class-level default first: that call validates `is_validate` and may raise.
    # Assigning the module-level flag only afterwards keeps the two settings in sync when
    # the value is rejected.
    Distribution.set_default_validate_args(is_validate)
    _VALIDATION_ENABLED = is_validate
@contextmanager
def validation_enabled(is_validate=True):
    """
    Context manager that switches validation checks on or off for the duration of a
    ``with`` block. The setting that was active on entry is reinstated on exit, including
    when the body raises.

    :param bool is_validate: whether to enable validation checks inside the block.
    """
    # Remember the module-level setting so that it can be restored afterwards.
    previous_status = _VALIDATION_ENABLED
    try:
        enable_validation(is_validate)
        yield
    finally:
        enable_validation(previous_status)
class Distribution(object):
    """
    Base class for probability distributions in NumPyro. The design largely
    follows from :mod:`torch.distributions`.

    :param batch_shape: The batch shape for the distribution. This designates
        independent (possibly non-identical) dimensions of a sample from the
        distribution. This is fixed for a distribution instance and is inferred
        from the shape of the distribution parameters.
    :param event_shape: The event shape for the distribution. This designates
        the dependent dimensions of a sample from the distribution. These are
        collapsed when we evaluate the log probability density of a batch of
        samples using `.log_prob`.
    :param validate_args: Whether to enable validation of distribution
        parameters and arguments to `.log_prob` method.

    As an example:

    .. doctest::

       >>> import jax.numpy as jnp
       >>> import numpyro.distributions as dist
       >>> d = dist.Dirichlet(jnp.ones((2, 3, 4)))
       >>> d.batch_shape
       (2, 3)
       >>> d.event_shape
       (4,)
    """
    arg_constraints = {}
    support = None
    has_enumerate_support = False
    is_discrete = False
    reparametrized_params = []
    _validate_args = False

    # register Distribution as a pytree
    # ref: https://github.com/google/jax/issues/2916
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        tree_util.register_pytree_node(cls,
                                       cls.tree_flatten,
                                       cls.tree_unflatten)

    def tree_flatten(self):
        """
        Flatten this distribution for JAX's pytree machinery.

        :return: a pair `(children, aux_data)` where `children` holds the values of the
            attributes named in `arg_constraints` (in sorted key order) and `aux_data`
            is `None`.
        :rtype: tuple
        """
        return tuple(getattr(self, param) for param in sorted(self.arg_constraints.keys())), None

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """
        Rebuild an instance from the output of :meth:`tree_flatten`.

        :param aux_data: auxiliary data produced by :meth:`tree_flatten` (unused here).
        :param params: parameter values, ordered like the sorted keys of `arg_constraints`.
        :return: a new instance constructed with `params` passed as keyword arguments.
        """
        return cls(**dict(zip(sorted(cls.arg_constraints.keys()), params)))

    @staticmethod
    def set_default_validate_args(value):
        """
        Set the class-level default for argument validation.

        :param bool value: the new default.
        :raises ValueError: if `value` is neither `True` nor `False`.
        """
        if value not in [True, False]:
            raise ValueError("validate_args must be True or False, got {!r}".format(value))
        Distribution._validate_args = value

    def __init__(self, batch_shape=(), event_shape=(), validate_args=None):
        self._batch_shape = batch_shape
        self._event_shape = event_shape
        if validate_args is not None:
            # an explicit per-instance setting shadows the class-level default
            self._validate_args = validate_args
        if self._validate_args:
            for param, constraint in self.arg_constraints.items():
                # do not force evaluation of a lazy_property that has not been computed yet
                if param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property):
                    continue
                if is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
                is_valid = jnp.all(constraint(getattr(self, param)))
                # the result can only be inspected when it is a concrete (non-traced) value
                if not_jax_tracer(is_valid):
                    if not is_valid:
                        raise ValueError("The parameter {} has invalid values".format(param))
        super(Distribution, self).__init__()

    @property
    def batch_shape(self):
        """
        Returns the shape over which the distribution parameters are batched.

        :return: batch shape of the distribution.
        :rtype: tuple
        """
        return self._batch_shape

    @property
    def event_shape(self):
        """
        Returns the shape of a single sample from the distribution without
        batching.

        :return: event shape of the distribution.
        :rtype: tuple
        """
        return self._event_shape

    @property
    def event_dim(self):
        """
        :return: Number of dimensions of individual events.
        :rtype: int
        """
        return len(self.event_shape)

    def shape(self, sample_shape=()):
        """
        The tensor shape of samples from this distribution.

        Samples are of shape::

            d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

        :param sample_shape: the size of the iid batch to be drawn from the
            distribution. Any sequence of ints is accepted; it is converted to a tuple.
        :return: shape of samples.
        :rtype: tuple
        """
        # coerce like `expand`/`expand_by` do, so that a list works as well as a tuple
        return tuple(sample_shape) + self.batch_shape + self.event_shape

    def sample(self, key, sample_shape=()):
        """
        Returns a sample from the distribution having shape given by
        `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
        leading dimensions (of size `sample_shape`) of the returned sample will
        be filled with iid draws from the distribution instance.

        :param jax.random.PRNGKey key: the rng_key key to be used for the distribution.
        :param tuple sample_shape: the sample shape for the distribution.
        :return: an array of shape `sample_shape + batch_shape + event_shape`
        :rtype: numpy.ndarray
        """
        raise NotImplementedError

    def log_prob(self, value):
        """
        Evaluates the log probability density for a batch of samples given by
        `value`.

        :param value: A batch of samples from the distribution.
        :return: an array whose shape is that of `value` without its rightmost
            `self.event_dim` dimensions.
        :rtype: numpy.ndarray
        """
        raise NotImplementedError

    @property
    def mean(self):
        """
        Mean of the distribution.
        """
        raise NotImplementedError

    @property
    def variance(self):
        """
        Variance of the distribution.
        """
        raise NotImplementedError

    def _validate_sample(self, value):
        """
        Evaluate the support constraint on `value` and warn about out-of-support entries.

        :param value: values to check against `self.support`.
        :return: the result of `self.support(value)`.
        """
        mask = self.support(value)
        # a warning can only be issued when the mask is a concrete (non-traced) value
        if not_jax_tracer(mask):
            if not jnp.all(mask):
                warnings.warn('Out-of-support values provided to log prob method. '
                              'The value argument should be within the support.')
        return mask

    def __call__(self, *args, **kwargs):
        """
        Draw a sample, taking the random key from the required keyword argument `rng_key`.

        If the keyword argument `sample_intermediates` is truthy, the call is forwarded to
        `sample_with_intermediates`; otherwise it is forwarded to :meth:`sample`. Remaining
        positional and keyword arguments are passed through unchanged.

        :raises KeyError: if `rng_key` is not supplied as a keyword argument.
        """
        key = kwargs.pop('rng_key')
        sample_intermediates = kwargs.pop('sample_intermediates', False)
        if sample_intermediates:
            return self.sample_with_intermediates(key, *args, **kwargs)
        return self.sample(key, *args, **kwargs)

    def to_event(self, reinterpreted_batch_ndims=None):
        """
        Interpret the rightmost `reinterpreted_batch_ndims` batch dimensions as
        dependent event dimensions.

        :param reinterpreted_batch_ndims: Number of rightmost batch dims to
            interpret as event dims. Defaults to all batch dims. When `0` is passed
            explicitly, this distribution itself is returned.
        :return: An instance of `Independent` distribution.
        :rtype: Independent
        """
        if reinterpreted_batch_ndims is None:
            reinterpreted_batch_ndims = len(self.batch_shape)
        elif reinterpreted_batch_ndims == 0:
            return self
        return Independent(self, reinterpreted_batch_ndims)

    def enumerate_support(self, expand=True):
        """
        Returns an array with shape `len(support) x batch_shape`
        containing all values in the support.
        """
        raise NotImplementedError

    def expand(self, batch_shape):
        """
        Returns a new :class:`ExpandedDistribution` instance with batch
        dimensions expanded to `batch_shape`. When `batch_shape` already equals
        the current batch shape, this distribution itself is returned.

        :param tuple batch_shape: batch shape to expand to.
        :return: an instance of `ExpandedDistribution`.
        :rtype: :class:`ExpandedDistribution`
        """
        batch_shape = tuple(batch_shape)
        if batch_shape == self.batch_shape:
            return self
        return ExpandedDistribution(self, batch_shape)

    def expand_by(self, sample_shape):
        """
        Expands a distribution by adding ``sample_shape`` to the left side of
        its :attr:`~numpyro.distributions.distribution.Distribution.batch_shape`.
        To expand internal dims of ``self.batch_shape`` from 1 to something
        larger, use :meth:`expand` instead.

        :param tuple sample_shape: The size of the iid batch to be drawn
            from the distribution.
        :return: An expanded version of this distribution.
        :rtype: :class:`ExpandedDistribution`
        """
        return self.expand(tuple(sample_shape) + self._batch_shape)

    def mask(self, mask):
        """
        Masks a distribution by a boolean or boolean-valued array that is
        broadcastable to the distributions
        :attr:`Distribution.batch_shape` .

        :param mask: A boolean or boolean valued array (`True` includes
            a site, `False` excludes a site).
        :type mask: bool or jnp.ndarray
        :return: A masked copy of this distribution, or this distribution itself
            when `mask is True`.
        :rtype: :class:`MaskedDistribution`
        """
        if mask is True:
            return self
        return MaskedDistribution(self, mask)
class ExpandedDistribution(Distribution):
    """
    Wraps a base distribution and presents it with a larger batch shape.

    Passing an :class:`ExpandedDistribution` as `base_dist` does not nest wrappers: its
    underlying base distribution is used and the requested shape is broadcast against
    the already expanded batch shape.

    :param base_dist: the distribution to expand.
    :param tuple batch_shape: the batch shape to expand to.
    """
    arg_constraints = {}

    def __init__(self, base_dist, batch_shape=()):
        if isinstance(base_dist, ExpandedDistribution):
            # `_broadcast_shape` returns (shape, expanded_sizes, interstitial_sizes);
            # only the broadcast shape is needed here.
            batch_shape, _, _ = self._broadcast_shape(base_dist.batch_shape, batch_shape)
            base_dist = base_dist.base_dist
        self.base_dist = base_dist
        super().__init__(base_dist.batch_shape, base_dist.event_shape)
        # adjust batch shape
        self.expand(batch_shape)

    def expand(self, batch_shape):
        """
        Expand the batch shape of this distribution to `batch_shape`.

        .. note:: This updates the instance in place and returns the instance itself,
            not a copy.

        :param tuple batch_shape: batch shape to expand to.
        :return: this distribution.
        :raises ValueError: if the current batch shape cannot be broadcast to `batch_shape`.
        """
        # Do basic validation. e.g. we should not "unexpand" distributions even if that is possible.
        new_shape, _, _ = self._broadcast_shape(self.batch_shape, batch_shape)
        # Record interstitial and expanded dims/sizes w.r.t. the base distribution
        new_shape, expanded_sizes, interstitial_sizes = self._broadcast_shape(self.base_dist.batch_shape,
                                                                              new_shape)
        self._batch_shape = new_shape
        self._expanded_sizes = expanded_sizes
        self._interstitial_sizes = interstitial_sizes
        return self

    @staticmethod
    def _broadcast_shape(existing_shape, new_shape):
        """
        Broadcast `existing_shape` to `new_shape`, aligning dimensions from the right.

        :param tuple existing_shape: the shape to start from.
        :param tuple new_shape: the target shape; must have at least as many dims.
        :return: a triple `(shape, expanded_sizes, interstitial_sizes)`. `expanded_sizes`
            maps the negative index of each dim that only exists in `new_shape` to its
            size; `interstitial_sizes` maps the negative index of each dim that grows
            from size 1 to a larger size to that size.
        :raises ValueError: if `new_shape` has fewer dims than `existing_shape`, or if a
            dim of `existing_shape` is neither 1 nor equal to the matching target dim.
        """
        if len(new_shape) < len(existing_shape):
            raise ValueError("Cannot broadcast distribution of shape {} to shape {}"
                             .format(existing_shape, new_shape))
        # walk both shapes from the rightmost dim
        reversed_shape = list(reversed(existing_shape))
        expanded_sizes, interstitial_sizes = [], []
        for i, size in enumerate(reversed(new_shape)):
            if i >= len(reversed_shape):
                # new leading dim that `existing_shape` does not have
                reversed_shape.append(size)
                expanded_sizes.append((-i - 1, size))
            elif reversed_shape[i] == 1:
                if size != 1:
                    # size-1 dim that is stretched to `size`
                    reversed_shape[i] = size
                    interstitial_sizes.append((-i - 1, size))
            elif reversed_shape[i] != size:
                raise ValueError("Cannot broadcast distribution of shape {} to shape {}"
                                 .format(existing_shape, new_shape))
        return tuple(reversed(reversed_shape)), OrderedDict(expanded_sizes), OrderedDict(interstitial_sizes)

    @property
    def has_enumerate_support(self):
        return self.base_dist.has_enumerate_support

    @property
    def is_discrete(self):
        return self.base_dist.is_discrete

    @property
    def support(self):
        return self.base_dist.support

    def sample(self, key, sample_shape=()):
        """
        Draw samples of shape `sample_shape + batch_shape + event_shape`.

        :param jax.random.PRNGKey key: the random key passed on to the base distribution.
        :param tuple sample_shape: the sample shape for the distribution.
        :return: an array of shape `sample_shape + batch_shape + event_shape`.
        """
        interstitial_dims = tuple(self._interstitial_sizes.keys())
        event_dim = len(self.event_shape)
        # the recorded dims index the batch shape; shift them past the event dims so that
        # they index the full sample
        interstitial_dims = tuple(i - event_dim for i in interstitial_dims)
        interstitial_sizes = tuple(self._interstitial_sizes.values())
        expanded_sizes = tuple(self._expanded_sizes.values())
        # request the extra (expanded and interstitial) draws as leading sample dims
        batch_shape = expanded_sizes + interstitial_sizes
        samples = self.base_dist.sample(key, sample_shape + batch_shape)
        interstitial_idx = len(sample_shape) + len(expanded_sizes)
        interstitial_sample_dims = tuple(range(interstitial_idx, interstitial_idx + len(interstitial_sizes)))
        # swap each leading interstitial axis with the base batch axis it stands in for
        for dim1, dim2 in zip(interstitial_dims, interstitial_sample_dims):
            samples = jnp.swapaxes(samples, dim1, dim2)
        return samples.reshape(sample_shape + self.batch_shape + self.event_shape)

    def log_prob(self, value):
        """
        Evaluate the base distribution's log density at `value`, broadcast to the expanded
        batch shape.

        :param value: A batch of samples from the distribution.
        :return: the log density, broadcast to the broadcast of `batch_shape` and the
            batch dims of `value`.
        """
        shape = lax.broadcast_shapes(self.batch_shape,
                                     jnp.shape(value)[:max(jnp.ndim(value) - self.event_dim, 0)])
        log_prob = self.base_dist.log_prob(value)
        return jnp.broadcast_to(log_prob, shape)

    def enumerate_support(self, expand=True):
        """
        Enumerate the support of the base distribution.

        :param bool expand: whether to broadcast the result over `batch_shape`; when
            `False` the batch dims all have size 1.
        :return: an array whose first dim enumerates the support values.
        """
        samples = self.base_dist.enumerate_support(expand=False)
        enum_shape = samples.shape[:1]
        samples = samples.reshape(enum_shape + (1,) * len(self.batch_shape))
        if expand:
            # JAX arrays have no `.expand` method; broadcast explicitly instead
            samples = jnp.broadcast_to(samples, enum_shape + self.batch_shape)
        return samples

    @property
    def mean(self):
        return jnp.broadcast_to(self.base_dist.mean, self.batch_shape + self.event_shape)

    @property
    def variance(self):
        return jnp.broadcast_to(self.base_dist.variance, self.batch_shape + self.event_shape)

    def tree_flatten(self):
        """
        Flatten into the base distribution's parameters plus the auxiliary data needed to
        rebuild the expanded distribution.

        :return: `(base_flatten, (base_cls, base_aux, batch_shape))`.
        """
        prepend_ndim = len(self.batch_shape) - len(self.base_dist.batch_shape)
        # prepend size-1 dims to every leaf of the base distribution, one per batch dim
        # that the expansion added on the left
        base_dist = tree_util.tree_map(
            lambda x: promote_shapes(x, shape=(1,) * prepend_ndim + jnp.shape(x))[0],
            self.base_dist)
        base_flatten, base_aux = base_dist.tree_flatten()
        return base_flatten, (type(self.base_dist), base_aux, self.batch_shape)

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """
        Rebuild an instance from the output of :meth:`tree_flatten`.

        :param aux_data: the triple `(base_cls, base_aux, batch_shape)`.
        :param params: the flattened parameters of the base distribution.
        :return: a new :class:`ExpandedDistribution`.
        """
        base_cls, base_aux, batch_shape = aux_data
        base_dist = base_cls.tree_unflatten(base_aux, params)
        # keep any leading batch dims of the rebuilt base distribution that go beyond
        # the recorded batch shape
        prepend_shape = base_dist.batch_shape[:len(base_dist.batch_shape) - len(batch_shape)]
        return cls(base_dist, batch_shape=prepend_shape + batch_shape)
class Independent(Distribution):
    """
    Reinterprets batch dimensions of a distribution as event dims by shifting
    the batch-event dim boundary further to the left.

    From a practical standpoint, this is useful when changing the result of
    :meth:`log_prob`. For example, a univariate Normal distribution can be
    interpreted as a multivariate Normal with diagonal covariance:

    .. doctest::

        >>> import numpyro.distributions as dist
        >>> normal = dist.Normal(jnp.zeros(3), jnp.ones(3))
        >>> [normal.batch_shape, normal.event_shape]
        [(3,), ()]
        >>> diag_normal = dist.Independent(normal, 1)
        >>> [diag_normal.batch_shape, diag_normal.event_shape]
        [(), (3,)]

    :param numpyro.distribution.Distribution base_distribution: a distribution instance.
    :param int reinterpreted_batch_ndims: the number of batch dims to reinterpret as event dims.
    :raises ValueError: if `reinterpreted_batch_ndims` exceeds the number of batch dims
        of the base distribution.
    """
    arg_constraints = {}

    def __init__(self, base_dist, reinterpreted_batch_ndims, validate_args=None):
        if reinterpreted_batch_ndims > len(base_dist.batch_shape):
            raise ValueError("Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
                             "actual {} vs {}".format(reinterpreted_batch_ndims,
                                                      len(base_dist.batch_shape)))
        # split the full sample shape at the new batch/event boundary
        shape = base_dist.batch_shape + base_dist.event_shape
        event_dim = reinterpreted_batch_ndims + len(base_dist.event_shape)
        batch_shape = shape[:len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim:]
        self.base_dist = base_dist
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super(Independent, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    @property
    def support(self):
        return self.base_dist.support

    @property
    def has_enumerate_support(self):
        return self.base_dist.has_enumerate_support

    @property
    def is_discrete(self):
        return self.base_dist.is_discrete

    @property
    def reparametrized_params(self):
        # `Distribution` spells this attribute `reparametrized_params`; delegate under
        # that name so the base distribution's value is reported.
        return self.base_dist.reparametrized_params

    @property
    def reparameterized_params(self):
        # Alias kept for callers using this spelling. It previously looked up the same
        # spelling on the base distribution, which `Distribution` does not define.
        return self.base_dist.reparametrized_params

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, key, sample_shape=()):
        """
        Draw samples from the base distribution.

        :param jax.random.PRNGKey key: the random key passed on to the base distribution.
        :param tuple sample_shape: the sample shape for the distribution.
        :return: the samples returned by the base distribution.
        """
        return self.base_dist.sample(key, sample_shape=sample_shape)

    def log_prob(self, value):
        """
        Evaluate the base distribution's log density at `value` and sum it over the
        rightmost `reinterpreted_batch_ndims` dimensions.

        :param value: A batch of samples from the distribution.
        :return: the summed log density.
        """
        log_prob = self.base_dist.log_prob(value)
        return sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

    def expand(self, batch_shape):
        """
        Expand the batch dimensions to `batch_shape`.

        The base distribution is expanded to `batch_shape` followed by the reinterpreted
        dims, and the result is converted back with `to_event`.

        :param tuple batch_shape: batch shape to expand to.
        :return: the expanded distribution.
        """
        # coerce like `Distribution.expand` does, so that a list works as well as a tuple
        base_batch_shape = tuple(batch_shape) + self.event_shape[:self.reinterpreted_batch_ndims]
        return self.base_dist.expand(base_batch_shape).to_event(self.reinterpreted_batch_ndims)

    def tree_flatten(self):
        """
        Flatten into the base distribution's parameters plus the auxiliary data needed to
        rebuild this wrapper.

        :return: `(base_flatten, (base_cls, base_aux, reinterpreted_batch_ndims))`.
        """
        base_flatten, base_aux = self.base_dist.tree_flatten()
        return base_flatten, (type(self.base_dist), base_aux, self.reinterpreted_batch_ndims)

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """
        Rebuild an instance from the output of :meth:`tree_flatten`.

        :param aux_data: the triple `(base_cls, base_aux, reinterpreted_batch_ndims)`.
        :param params: the flattened parameters of the base distribution.
        :return: a new :class:`Independent`.
        """
        base_cls, base_aux, reinterpreted_batch_ndims = aux_data
        base_dist = base_cls.tree_unflatten(base_aux, params)
        return cls(base_dist, reinterpreted_batch_ndims)
class MaskedDistribution(Distribution):
    """
    Wraps a distribution together with a boolean mask that is broadcastable to the
    distribution's :attr:`Distribution.batch_shape`.

    In the special case ``mask is False``, :meth:`log_prob` does not evaluate the base
    distribution at all and returns constant zeros instead.

    :param mask: A boolean or boolean-valued array.
    :type mask: jnp.ndarray or bool
    """
    arg_constraints = {}

    def __init__(self, base_dist, mask):
        if not isinstance(mask, bool):
            # array-valued mask: bring mask and base distribution to a common batch shape
            common_shape = lax.broadcast_shapes(jnp.shape(mask), base_dist.batch_shape)
            if mask.shape != common_shape:
                mask = jnp.broadcast_to(mask, common_shape)
            if base_dist.batch_shape != common_shape:
                base_dist = base_dist.expand(common_shape)
            mask = mask.astype('bool')
        self._mask = mask
        self.base_dist = base_dist
        super().__init__(base_dist.batch_shape, base_dist.event_shape)

    @property
    def has_enumerate_support(self):
        return self.base_dist.has_enumerate_support

    @property
    def is_discrete(self):
        return self.base_dist.is_discrete

    @property
    def support(self):
        return self.base_dist.support

    def sample(self, key, sample_shape=()):
        """
        Draw samples from the base distribution; the mask is not applied to samples.
        """
        return self.base_dist.sample(key, sample_shape)

    def log_prob(self, value):
        """
        Log density of the base distribution, replaced by zero wherever the mask is false.

        :param value: A batch of samples from the distribution.
        :return: the masked log density.
        """
        if self._mask is True:
            return self.base_dist.log_prob(value)
        if self._mask is False:
            # everything is masked out: skip the base computation and return zeros
            value_batch_shape = jnp.shape(value)[:max(jnp.ndim(value) - len(self.event_shape), 0)]
            return jnp.zeros(lax.broadcast_shapes(self.base_dist.batch_shape, value_batch_shape))
        unmasked = self.base_dist.log_prob(value)
        return jnp.where(self._mask, unmasked, 0.)

    def enumerate_support(self, expand=True):
        return self.base_dist.enumerate_support(expand=expand)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def tree_flatten(self):
        """
        Flatten for JAX's pytree machinery.

        An array-valued mask is returned as a child next to the base parameters; a Python
        ``bool`` mask is stored in the auxiliary data instead.
        """
        base_flatten, base_aux = self.base_dist.tree_flatten()
        base_cls = type(self.base_dist)
        if not isinstance(self._mask, bool):
            return (base_flatten, self._mask), (base_cls, base_aux)
        return base_flatten, (base_cls, base_aux, self._mask)

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """
        Rebuild an instance from the output of :meth:`tree_flatten`.

        The length of `aux_data` tells the two layouts apart: two entries mean the mask
        is part of `params`, otherwise it is the last entry of `aux_data`.
        """
        if len(aux_data) != 2:
            base_cls, base_aux, mask = aux_data
            base_flatten = params
        else:
            base_cls, base_aux = aux_data
            base_flatten, mask = params
        return cls(base_cls.tree_unflatten(base_aux, base_flatten), mask)
class Unit(Distribution):
    """
    Trivial nonnormalized distribution representing the unit type.

    The unit type has a single value with no data, i.e. ``value.size == 0``.

    This is used for :func:`numpyro.factor` statements.
    """
    arg_constraints = {'log_factor': real}
    support = real

    def __init__(self, log_factor, validate_args=None):
        self.log_factor = log_factor
        # An event shape of (0,) makes every sample satisfy .size == 0.
        super().__init__(jnp.shape(log_factor), (0,), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """
        Return an empty array of shape `sample_shape + batch_shape + event_shape`;
        `key` is not used.
        """
        out_shape = sample_shape + self.batch_shape + self.event_shape
        return jnp.empty(out_shape)

    def log_prob(self, value):
        """
        Return `log_factor` broadcast against the shape of `value` without its last dim.
        """
        value_batch_shape = jnp.shape(value)[:-1]
        target_shape = lax.broadcast_shapes(self.batch_shape, value_batch_shape)
        return jnp.broadcast_to(self.log_factor, target_shape)