Source code for numpyro.distributions.mixtures

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import jax
from jax import lax
import jax.numpy as jnp

from numpyro.distributions import Distribution, constraints
from numpyro.distributions.discrete import CategoricalLogits, CategoricalProbs
from numpyro.distributions.util import validate_sample
from numpyro.util import is_prng_key


def Mixture(mixing_distribution, component_distributions, *, validate_args=None):
    """
    Build a marginalized finite mixture of component distributions.

    The concrete class of the returned distribution depends on
    ``component_distributions``:

    1. :class:`~numpyro.distributions.MixtureSameFamily` is returned when
       ``component_distributions`` is a single distribution, and
    2. :class:`~numpyro.distributions.MixtureGeneral` is returned otherwise,
       e.g. when ``component_distributions`` is a list.

    More details can be found in the documentation of each of these classes.

    :param mixing_distribution: A :class:`~numpyro.distributions.Categorical`
        specifying the weights of the mixture components. The size of this
        distribution specifies the number of components in the mixture,
        ``mixture_size``.
    :param component_distributions: Either a list of component distributions
        or a single vectorized distribution. When a list is provided, the
        number of elements must equal ``mixture_size``. Otherwise, the last
        batch dimension of the distribution must equal ``mixture_size``.
    :param validate_args: Forwarded unchanged to the selected mixture class.
    :return: The mixture distribution.
    """
    # A single Distribution instance means "vectorized components"; anything
    # else is handed to the general implementation, which validates it.
    if isinstance(component_distributions, Distribution):
        mixture_cls = MixtureSameFamily
    else:
        mixture_cls = MixtureGeneral
    return mixture_cls(
        mixing_distribution, component_distributions, validate_args=validate_args
    )
class _MixtureBase(Distribution):
    """Abstract base class holding the logic shared by mixture distributions.

    Subclasses specialize the behavior by implementing the ``component_*``
    members and by setting the ``_mixture_size`` and ``_mixing_distribution``
    attributes that the properties below read.
    """

    @property
    def component_mean(self):
        raise NotImplementedError

    @property
    def component_variance(self):
        raise NotImplementedError

    def component_log_probs(self, value):
        raise NotImplementedError

    def component_sample(self, key, sample_shape=()):
        raise NotImplementedError

    def component_cdf(self, samples):
        raise NotImplementedError

    @property
    def mixture_size(self):
        """The number of components in the mixture"""
        return self._mixture_size

    @property
    def mixing_distribution(self):
        """The ``Categorical`` distribution over components"""
        return self._mixing_distribution

    @property
    def mixture_dim(self):
        """Negative axis index of the component dimension.

        It is the axis located immediately before the ``event_dim`` trailing
        event axes.
        """
        return -(self.event_dim + 1)

    @property
    def mean(self):
        """Mean of the mixture: the weighted sum of the component means."""
        weights = self.mixing_distribution.probs
        # Append one singleton axis per event dimension so the weights
        # broadcast against the component means.
        trailing_ones = (1,) * self.event_dim
        weights = weights.reshape(weights.shape + trailing_ones)
        return jnp.sum(weights * self.component_mean, axis=self.mixture_dim)

    @property
    def variance(self):
        """Variance of the mixture.

        Computed as the weighted sum of the component variances plus the
        weighted sum of the squared deviations of the component means from
        the mixture mean.
        """
        weights = self.mixing_distribution.probs
        weights = weights.reshape(weights.shape + (1,) * self.event_dim)
        axis = self.mixture_dim
        expected_component_var = jnp.sum(weights * self.component_variance, axis=axis)
        centered_means = self.component_mean - jnp.expand_dims(self.mean, axis=axis)
        var_of_component_means = jnp.sum(weights * centered_means**2, axis=axis)
        return expected_component_var + var_of_component_means

    def cdf(self, samples):
        """The cumulative distribution function

        :param samples: samples from this distribution.
        :return: output of the cumulative distribution function evaluated at
            `samples`.
        :raises: NotImplementedError if the component distribution does not
            implement the cdf method.
        """
        weighted_cdfs = self.component_cdf(samples) * self.mixing_distribution.probs
        return jnp.sum(weighted_cdfs, axis=-1)

    def sample_with_intermediates(self, key, sample_shape=()):
        """
        A version of ``sample`` that also returns the sampled component indices

        :param jax.random.PRNGKey key: the rng_key key to be used for the
            distribution.
        :param tuple sample_shape: the sample shape for the distribution.
        :return: A 2-element tuple with the samples from the distribution, and
            a one-element list holding the indices of the sampled components.
        :rtype: tuple
        """
        assert is_prng_key(key)
        component_key, index_key = jax.random.split(key)

        # Draw from every component first; one draw per entry is kept below.
        all_component_draws = self.component_sample(
            component_key, sample_shape=sample_shape
        )

        # One component index per (sample, batch) entry.
        full_shape = sample_shape + self.batch_shape
        component_ids = self.mixing_distribution.expand(full_shape).sample(index_key)

        # Add singleton axes for the mixture axis and every event axis so the
        # indices can be used with take_along_axis.
        n_singleton = self.event_dim + 1
        gather_index = component_ids.reshape(
            component_ids.shape + (1,) * n_singleton
        )
        picked = jnp.take_along_axis(
            all_component_draws, indices=gather_index, axis=self.mixture_dim
        )

        # Dropping the (now length-1) mixture axis leaves
        # (*sample_shape, *batch_shape, *event_shape).
        return jnp.squeeze(picked, axis=self.mixture_dim), [component_ids]

    def sample(self, key, sample_shape=()):
        """Draw samples, discarding the sampled component indices."""
        drawn, _ = self.sample_with_intermediates(key=key, sample_shape=sample_shape)
        return drawn

    @validate_sample
    def log_prob(self, value, intermediates=None):
        """Log-density of ``value``.

        The values returned by ``component_log_probs`` are combined with a
        log-sum-exp over the last axis; ``intermediates`` is accepted but
        ignored.
        """
        del intermediates
        return jax.nn.logsumexp(self.component_log_probs(value), axis=-1)
class MixtureSameFamily(_MixtureBase):
    """
    A finite mixture of component distributions from the same family

    This mixture only supports a mixture of component distributions that are all
    of the same family. The different components are specified along the last
    batch dimension of the input ``component_distribution``. If you need a
    mixture of distributions from different families, use the more general
    implementation in :class:`~numpyro.distributions.MixtureGeneral`.

    :param mixing_distribution: A :class:`~numpyro.distributions.Categorical`
        specifying the weights of the mixture components. The size of this
        distribution specifies the number of components in the mixture,
        ``mixture_size``.
    :param component_distribution: A single vectorized
        :class:`~numpyro.distributions.Distribution`, whose last batch
        dimension equals ``mixture_size`` as specified by
        ``mixing_distribution``.
    :raises ValueError: if ``mixing_distribution`` is not a categorical
        distribution, if ``component_distribution`` is not a ``Distribution``,
        or if ``component_distribution`` has an empty batch shape.
    :raises AssertionError: if the last batch dimension of
        ``component_distribution`` differs from ``mixture_size``.

    **Example**

    .. doctest::

       >>> import jax
       >>> import jax.numpy as jnp
       >>> import numpyro.distributions as dist
       >>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.)
       >>> component_dist = dist.Normal(loc=jnp.zeros(3), scale=jnp.ones(3))
       >>> mixture = dist.MixtureSameFamily(mixing_dist, component_dist)
       >>> mixture.sample(jax.random.PRNGKey(42)).shape
       ()
    """

    pytree_data_fields = ("_mixing_distribution", "_component_distribution")
    pytree_aux_fields = ("_mixture_size",)

    def __init__(
        self, mixing_distribution, component_distribution, *, validate_args=None
    ):
        _check_mixing_distribution(mixing_distribution)
        mixture_size = mixing_distribution.probs.shape[-1]
        if not isinstance(component_distribution, Distribution):
            raise ValueError(
                "The component distribution need to be a numpyro.distributions.Distribution. "
                f"However, it is of type {type(component_distribution)}"
            )
        component_batch_shape = component_distribution.batch_shape
        # Without this guard, indexing the empty batch shape below fails with
        # an uninformative "tuple index out of range" IndexError.
        if not component_batch_shape:
            raise ValueError(
                "The component distribution must have at least one batch dimension, "
                "whose size equals the number of mixture components "
                f"(mixture_size={mixture_size}), but its batch shape is empty."
            )
        # Kept as an assert so that callers relying on AssertionError for a
        # size mismatch keep working.
        assert component_batch_shape[-1] == mixture_size, (
            "Component distribution batch shape last dimension "
            f"(size={component_batch_shape[-1]}) "
            f"needs to correspond to the mixture_size={mixture_size}!"
        )
        self._mixing_distribution = mixing_distribution
        self._component_distribution = component_distribution
        self._mixture_size = mixture_size
        batch_shape = lax.broadcast_shapes(
            mixing_distribution.batch_shape,
            component_batch_shape[:-1],  # Without the mixture dimension
        )
        super().__init__(
            batch_shape=batch_shape,
            event_shape=component_distribution.event_shape,
            validate_args=validate_args,
        )

    @property
    def component_distribution(self):
        """
        Return the vectorized distribution of components being mixed.

        :return: Component distribution
        :rtype: Distribution
        """
        return self._component_distribution

    @constraints.dependent_property
    def support(self):
        return self.component_distribution.support

    @property
    def is_discrete(self):
        return self.component_distribution.is_discrete

    @property
    def component_mean(self):
        return self.component_distribution.mean

    @property
    def component_variance(self):
        return self.component_distribution.variance

    def component_cdf(self, samples):
        """Evaluate the cdf of the component distribution at ``samples``.

        A singleton axis is inserted at ``mixture_dim`` so that ``samples``
        broadcasts against the component (last batch) dimension.
        """
        return self.component_distribution.cdf(
            jnp.expand_dims(samples, axis=self.mixture_dim)
        )

    def component_sample(self, key, sample_shape=()):
        """Draw samples from all components at once.

        The component distribution is expanded to the batch shape
        ``sample_shape + self.batch_shape + (self.mixture_size,)`` before
        sampling.
        """
        return self.component_distribution.expand(
            sample_shape + self.batch_shape + (self.mixture_size,)
        ).sample(key)

    def component_log_probs(self, value):
        """Per-component log-probabilities of ``value`` plus log mixing weights.

        ``value`` gets a singleton axis at ``mixture_dim`` so it broadcasts
        against the component dimension; the log mixing weights are the
        log-softmax of the mixing distribution's logits.
        """
        value = jnp.expand_dims(value, self.mixture_dim)
        component_log_probs = self.component_distribution.log_prob(value)
        return jax.nn.log_softmax(self.mixing_distribution.logits) + component_log_probs
class MixtureGeneral(_MixtureBase):
    """
    A finite mixture of component distributions from different families

    If all of the component distributions are from the same family, the more
    specific implementation in :class:`~numpyro.distributions.MixtureSameFamily`
    will be somewhat more efficient.

    :param mixing_distribution: A :class:`~numpyro.distributions.Categorical`
        specifying the weights of the mixture components. The size of this
        distribution specifies the number of components in the mixture,
        ``mixture_size``.
    :param component_distributions: A list of ``mixture_size``
        :class:`~numpyro.distributions.Distribution` objects.
    :raises ValueError: if ``mixing_distribution`` is not a categorical
        distribution, if ``component_distributions`` is not an iterable of
        ``mixture_size`` ``Distribution`` objects, or if the components do not
        share the same support type and event shape.

    **Example**

    .. doctest::

       >>> import jax
       >>> import jax.numpy as jnp
       >>> import numpyro.distributions as dist
       >>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.)
       >>> component_dists = [
       ...     dist.Normal(loc=0.0, scale=1.0),
       ...     dist.Normal(loc=-0.5, scale=0.3),
       ...     dist.Normal(loc=0.6, scale=1.2),
       ... ]
       >>> mixture = dist.MixtureGeneral(mixing_dist, component_dists)
       >>> mixture.sample(jax.random.PRNGKey(42)).shape
       ()
    """

    pytree_data_fields = ("_mixing_distribution", "_component_distributions")
    pytree_aux_fields = ("_mixture_size",)

    def __init__(
        self, mixing_distribution, component_distributions, *, validate_args=None
    ):
        _check_mixing_distribution(mixing_distribution)

        self._mixture_size = jnp.shape(mixing_distribution.probs)[-1]
        try:
            component_distributions = list(component_distributions)
        except TypeError as err:
            # Chain the original error so the reason the argument is not
            # iterable stays visible in the traceback.
            raise ValueError(
                "The 'component_distributions' argument must be a list of Distribution objects"
            ) from err
        for d in component_distributions:
            if not isinstance(d, Distribution):
                raise ValueError(
                    "All elements of 'component_distributions' must be instances of "
                    "numpyro.distributions.Distribution subclasses"
                )
        if len(component_distributions) != self.mixture_size:
            raise ValueError(
                "The number of elements in 'component_distributions' must match the mixture size; "
                f"expected {self._mixture_size}, got {len(component_distributions)}"
            )

        # TODO: It would be good to check that the support of all the component
        # distributions match, but for now we just check the type, since __eq__
        # isn't consistently implemented for all support types.
        support_type = type(component_distributions[0].support)
        if any(
            type(d.support) is not support_type for d in component_distributions[1:]
        ):
            raise ValueError("All component distributions must have the same support.")

        self._mixing_distribution = mixing_distribution
        self._component_distributions = component_distributions

        batch_shape = lax.broadcast_shapes(
            mixing_distribution.batch_shape,
            *(d.batch_shape for d in component_distributions),
        )
        event_shape = component_distributions[0].event_shape
        for d in component_distributions[1:]:
            if d.event_shape != event_shape:
                raise ValueError(
                    "All component distributions must have the same event shape"
                )

        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @property
    def component_distributions(self):
        """The list of component distributions in the mixture

        :return: The list of component distributions
        :rtype: list[Distribution]
        """
        return self._component_distributions

    @constraints.dependent_property
    def support(self):
        # Only the first component is consulted; the constructor checked that
        # all components share the same support type.
        return self.component_distributions[0].support

    @property
    def is_discrete(self):
        return self.component_distributions[0].is_discrete

    @property
    def component_mean(self):
        return jnp.stack(
            [d.mean for d in self.component_distributions], axis=self.mixture_dim
        )

    @property
    def component_variance(self):
        return jnp.stack(
            [d.variance for d in self.component_distributions], axis=self.mixture_dim
        )

    def component_cdf(self, samples):
        """Evaluate the cdf of every component at ``samples``.

        The per-component results are stacked along the last axis, which is
        the axis ``cdf`` reduces over and the same axis that
        ``component_log_probs`` uses. For ``event_dim == 0`` this equals
        ``mixture_dim``.
        """
        return jnp.stack(
            [d.cdf(samples) for d in self.component_distributions],
            axis=-1,
        )

    def component_sample(self, key, sample_shape=()):
        """Draw one sample per component, stacked along ``mixture_dim``.

        ``key`` is split into ``mixture_size`` keys, one per component, and
        each component is expanded to ``sample_shape + self.batch_shape``
        before sampling.
        """
        keys = jax.random.split(key, self.mixture_size)
        full_shape = sample_shape + self.batch_shape
        samples = [
            d.expand(full_shape).sample(k)
            for k, d in zip(keys, self.component_distributions)
        ]
        return jnp.stack(samples, axis=self.mixture_dim)

    def component_log_probs(self, value):
        """Per-component log-probabilities of ``value`` plus log mixing weights.

        The component log-probabilities are stacked along the last axis and
        added to the log-softmax of the mixing distribution's logits.
        """
        component_log_probs = jnp.stack(
            [d.log_prob(value) for d in self.component_distributions], axis=-1
        )
        return jax.nn.log_softmax(self.mixing_distribution.logits) + component_log_probs
def _check_mixing_distribution(mixing_distribution):
    """Validate that ``mixing_distribution`` is a categorical distribution.

    :param mixing_distribution: the object to validate.
    :raises ValueError: if it is neither a ``CategoricalLogits`` nor a
        ``CategoricalProbs`` instance.
    """
    accepted_types = (CategoricalLogits, CategoricalProbs)
    if isinstance(mixing_distribution, accepted_types):
        return
    raise ValueError(
        "The mixing distribution must be a numpyro.distributions.Categorical. "
        f"However, it is of type {type(mixing_distribution)}"
    )