# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from collections import namedtuple
from contextlib import ExitStack, contextmanager
import functools
from jax import lax, random
import jax.numpy as jnp
import numpyro
from numpyro.distributions.discrete import PRNGIdentity
from numpyro.util import identity
# Global stack of active effect handlers. `Messenger.__enter__` appends to it,
# `Messenger.__exit__` removes from it, and `apply_stack` walks it starting
# from the most recently entered handler.
_PYRO_STACK = []
# Record appended by `plate.process_message` to a message's `cond_indep_stack`;
# `plate` fills the `size` field with its subsample size.
CondIndepStackFrame = namedtuple('CondIndepStackFrame', ['name', 'dim', 'size'])
def apply_stack(msg):
    """
    Run a message through every active handler on ``_PYRO_STACK``.

    Handlers are visited from the most recently entered one to the oldest,
    each getting a chance to modify the message in ``process_message``. If the
    message still has no value afterwards, ``msg['fn']`` is called to produce
    one. Finally the handlers that took part in processing are given the
    message again through ``postprocess_message``.

    :param dict msg: the message describing the current site; it is modified
        in place.
    :return: the same (now updated) message dictionary.
    """
    # Index, counted from the top of the stack, of the last handler that
    # processed the message. Stays at 0 when the stack is empty.
    last_idx = 0
    for last_idx, messenger in enumerate(reversed(_PYRO_STACK)):
        messenger.process_message(msg)
        # A handler that sets "stop" hides the message from every handler
        # that was entered before it.
        if msg.get("stop"):
            break

    if msg['value'] is None:
        fn, args, kwargs = msg['fn'], msg['args'], msg['kwargs']
        if msg['type'] == 'sample':
            # sample sites also record the intermediate values
            msg['value'], msg['intermediates'] = fn(
                *args, sample_intermediates=True, **kwargs)
        else:
            msg['value'] = fn(*args, **kwargs)

    # Only the handlers reached by the loop above get to post-process, so a
    # "stop" also shields the message from postprocess_message of the
    # handlers that were skipped.
    for messenger in _PYRO_STACK[-(last_idx + 1):]:
        messenger.postprocess_message(msg)
    return msg
class Messenger(object):
    """
    Base class for effect handlers.

    A ``Messenger`` can be used as a context manager, in which case it is
    pushed onto the global handler stack for the duration of the ``with``
    block, or as a wrapper around a callable ``fn``, in which case calling the
    messenger runs ``fn`` with the messenger active.

    Subclasses override :meth:`process_message` and/or
    :meth:`postprocess_message` to act on messages sent by primitives.

    :param fn: optional Python callable to wrap.
    :raises ValueError: if ``fn`` is given but is not callable.
    """

    def __init__(self, fn=None):
        if fn is not None and not callable(fn):
            raise ValueError("Expected `fn` to be a Python callable object; "
                             "instead found type(fn) = {}.".format(type(fn)))
        self.fn = fn
        # Copy metadata (name, docstring, ...) of the wrapped callable.
        functools.update_wrapper(self, fn, updated=[])

    def __enter__(self):
        _PYRO_STACK.append(self)

    def __exit__(self, *args, **kwargs):
        """
        Remove this handler from the global stack.

        On a clean exit the handler must be on top of the stack. When the
        block is left because of an exception, handlers entered inside the
        block may not have been exited, so this handler and everything above
        it are dropped instead. That way the original exception propagates
        rather than being masked by an ``AssertionError``, and no stale
        handlers are left behind on the global stack.
        """
        exc_type = args[0] if args else kwargs.get('exc_type')
        if exc_type is None:
            assert _PYRO_STACK[-1] is self
            _PYRO_STACK.pop()
            return
        # Search from the top so that only the innermost activation of this
        # handler (and whatever was pushed after it) is unwound.
        for loc in range(len(_PYRO_STACK) - 1, -1, -1):
            if _PYRO_STACK[loc] is self:
                del _PYRO_STACK[loc:]
                break

    def process_message(self, msg):
        """Hook called before a site's value is computed; no-op by default."""
        pass

    def postprocess_message(self, msg):
        """Hook called after a site's value is computed; no-op by default."""
        pass

    def __call__(self, *args, **kwargs):
        """Run the wrapped ``fn`` with this handler active and return its result."""
        with self:
            return self.fn(*args, **kwargs)
def sample(name, fn, obs=None, rng_key=None, sample_shape=(), infer=None):
    """
    Draw a random value from the stochastic function `fn` at a named site.
    Effect handlers such as :class:`~numpyro.handlers.substitute` may
    intercept the site and change what is returned.

    .. note::
        The `sample` primitive is designed to be called inside a NumPyro
        model, where the :class:`~numpyro.handlers.seed` handler supplies the
        random state for `fn`. In that setting the `rng_key` keyword has no
        effect.

    :param str name: name of the sample site.
    :param fn: a stochastic function that returns a sample.
    :param numpy.ndarray obs: observed value
    :param jax.random.PRNGKey rng_key: an optional random key for `fn`.
    :param sample_shape: Shape of samples to be drawn.
    :param dict infer: an optional dictionary with extra information for
        inference algorithms. For example, if `fn` is a discrete distribution,
        `infer={'enumerate': 'parallel'}` tells MCMC to marginalize this
        discrete latent site.
    :return: sample from the stochastic `fn`.
    """
    # Without any active handler there is nothing to dispatch to: sample directly.
    if not _PYRO_STACK:
        return fn(rng_key=rng_key, sample_shape=sample_shape)

    infer_info = infer if infer is not None else {}
    # Describe the site as a message and let the handler stack process it.
    site = dict(
        type='sample',
        name=name,
        fn=fn,
        args=(),
        kwargs=dict(rng_key=rng_key, sample_shape=sample_shape),
        value=obs,
        scale=None,
        is_observed=obs is not None,
        intermediates=[],
        cond_indep_stack=[],
        infer=infer_info,
    )
    return apply_stack(site)['value']
def param(name, init_value=None, **kwargs):
    """
    Mark a site as an optimizable parameter for use with
    :mod:`jax.experimental.optimizers`. See :func:`~numpyro.svi.svi` for an
    example of how `param` statements are used by inference algorithms.

    :param str name: name of site.
    :param numpy.ndarray init_value: initial value specified by the user. Since
        NumPyro has no global parameter store, it is up to the user /
        inference algorithm to use this value to initialize the optimizer.
    :param constraint: NumPyro constraint, defaults to ``constraints.real``.
    :type constraint: numpyro.distributions.constraints.Constraint
    :param int event_dim: (optional) number of rightmost dimensions unrelated
        to batching. Dimensions to the left of these are treated as batch
        dimensions; when the param statement is inside a subsampled plate, the
        matching batch dimensions of the parameter are subsampled as well. If
        unspecified, all dimensions are treated as event dims and no
        subsampling is performed.
    :return: value for the parameter. Unless wrapped inside a
        handler like :class:`~numpyro.handlers.substitute`, this will simply
        return the initial value.
    """
    # Without any active handler the initial value is handed back untouched.
    if not _PYRO_STACK:
        return init_value

    # Describe the site as a message and let the handler stack process it;
    # `identity` returns `init_value` if no handler provides a value.
    site = dict(
        type='param',
        name=name,
        fn=identity,
        args=(init_value,),
        kwargs=kwargs,
        value=None,
        scale=None,
        cond_indep_stack=[],
    )
    return apply_stack(site)['value']
def deterministic(name, value):
    """
    Designate a deterministic site in the model. Most effect handlers do not
    operate on deterministic sites (:func:`~numpyro.handlers.trace` is the
    exception), so such sites should be side-effect free. Their purpose is to
    record a value in the model execution trace.

    :param str name: name of the deterministic site.
    :param numpy.ndarray value: deterministic value to record in the trace.
    :return: the value of the site after the active handlers have seen it;
        `value` itself when no handler is active.
    """
    if _PYRO_STACK:
        # Send the site through the active handlers.
        site = dict(type='deterministic', name=name, value=value)
        return apply_stack(site)['value']
    return value
def module(name, nn, input_shape=None):
    """
    Register a :mod:`~jax.experimental.stax` style neural network inside a
    model, so that its parameters become optimizable through
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param tuple nn: a tuple of `(init_fn, apply_fn)` obtained by a :mod:`~jax.experimental.stax`
        constructor function.
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :return: a `apply_fn` with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    :raises ValueError: if the parameters have to be initialized but
        `input_shape` is not given.
    """
    params_site = name + '$params'
    init_fn, apply_fn = nn

    params = param(params_site)
    if params is not None:
        # A handler already supplied the parameters.
        return functools.partial(apply_fn, params)

    # No parameters yet: initialize them with a key drawn at a dedicated site.
    if input_shape is None:
        raise ValueError('Valid value for `input_shape` needed to initialize.')
    key = numpyro.sample(name + '$rng_key', PRNGIdentity())
    _, params = init_fn(key, input_shape)
    param(params_site, params)
    return functools.partial(apply_fn, params)
def _subsample_fn(size, subsample_size, rng_key=None):
    """
    Draw `subsample_size` indices from a random permutation of ``range(size)``.

    :param int size: number of items to permute.
    :param int subsample_size: number of leading entries of the permutation to keep.
    :param jax.random.PRNGKey rng_key: random key driving the permutation; required.
    :return: the first `subsample_size` entries of the permutation.
    """
    assert rng_key is not None, "Missing random key to generate subsample indices."
    shuffled = random.permutation(rng_key, size)
    return shuffled[:subsample_size]
class plate(Messenger):
    """
    Context manager annotating conditionally independent variables. `sample`
    sites inside a `plate` are automatically broadcast to the size of the
    plate. When `subsample_size` is given, the message scale is multiplied by
    ``size / subsample_size`` so that inference algorithms can rescale.

    .. note:: This can be used to subsample minibatches of data:

        .. code-block:: python

            with plate("data", len(data), subsample_size=100) as ind:
                batch = data[ind]
                assert len(batch) == 100

    :param str name: Name of the plate.
    :param int size: Size of the plate.
    :param int subsample_size: Optional argument denoting the size of the mini-batch.
        This can be used to apply a scaling factor by inference algorithms. e.g.
        when computing ELBO using a mini-batch.
    :param int dim: Optional negative integer selecting which dimension of the
        tensor is used as the plate dim. If `None` (default), the first dim
        not used by an enclosing plate is allocated, searching from -1
        leftwards.
    :raises ValueError: if `dim` is given and is not negative.
    """

    def __init__(self, name, size, subsample_size=None, dim=None):
        self.name = name
        self.size = size
        if not (dim is None or dim < 0):
            raise ValueError('dim arg must be negative.')
        self.dim, self._indices = self._subsample(
            self.name, self.size, subsample_size, dim)
        # The effective subsample size is the number of indices obtained.
        self.subsample_size = self._indices.shape[0]
        super(plate, self).__init__()

    # XXX: different from Pyro, this method returns dim and indices
    @staticmethod
    def _subsample(name, size, subsample_size, dim):
        """
        Send a 'plate' message through the handler stack to obtain the
        subsample indices, and resolve the plate dimension.

        :return: tuple ``(dim, indices)``.
        :raises ValueError: if the number of indices obtained differs from a
            given `subsample_size`.
        """
        needs_draw = subsample_size is not None and size != subsample_size
        msg = {
            'type': 'plate',
            'fn': _subsample_fn,
            'name': name,
            'args': (size, subsample_size),
            'kwargs': {'rng_key': None},
            # Without subsampling every index is used; otherwise the value is
            # left empty so that it gets filled in by apply_stack.
            'value': None if needs_draw else jnp.arange(size),
            'scale': 1.0,
            'cond_indep_stack': [],
        }
        apply_stack(msg)
        indices = msg['value']
        if subsample_size is not None and subsample_size != indices.shape[0]:
            raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format(
                subsample_size, len(indices)) +
                " Did you accidentally use different subsample_size in the model and guide?")

        # Dims already claimed by enclosing plates that processed the message.
        taken = {frame.dim for frame in msg['cond_indep_stack']}
        if dim is not None:
            assert dim not in taken
            return dim, indices
        free_dim = -1
        while free_dim in taken:
            free_dim -= 1
        return free_dim, indices

    def __enter__(self):
        """Activate the plate and return its (sub)sample indices."""
        super().__enter__()
        return self._indices

    @staticmethod
    def _get_batch_shape(cond_indep_stack):
        """
        Build the batch shape implied by a list of plate frames: every frame
        places its size at its (negative) dim, all other dims have size 1.
        """
        rank = max(-frame.dim for frame in cond_indep_stack)
        shape = [1] * rank
        for frame in cond_indep_stack:
            shape[frame.dim] = frame.size
        return tuple(shape)

    def process_message(self, msg):
        """
        Register this plate on 'param', 'sample' and 'plate' messages, expand
        the distribution of 'sample' messages to the plate shape and update
        the message scale when subsampling.

        :raises NotImplementedError: for 'control_flow' messages.
        """
        site_type = msg['type']
        if site_type not in ('param', 'sample', 'plate'):
            if site_type == 'control_flow':
                raise NotImplementedError('Cannot use control flow primitive under a `plate` primitive.'
                                          ' Please move those `plate` statements into the control flow'
                                          ' body function. See `scan` documentation for more information.')
            return

        frames = msg['cond_indep_stack']
        frames.append(CondIndepStackFrame(self.name, self.dim, self.subsample_size))

        if site_type == 'sample':
            plate_shape = self._get_batch_shape(frames)
            fn_shape = msg['fn'].batch_shape
            if 'sample_shape' in msg['kwargs']:
                # Fold the requested sample shape into the batch shape and
                # clear it, since the expanded distribution will cover it.
                fn_shape = msg['kwargs']['sample_shape'] + fn_shape
                msg['kwargs']['sample_shape'] = ()
            # Plate dims that lie to the left of the distribution's own dims
            # are kept as they are; the overlapping part is broadcast.
            split = max(len(plate_shape) - len(fn_shape), 0)
            tail = lax.broadcast_shapes(plate_shape[split:], tuple(fn_shape))
            msg['fn'] = msg['fn'].expand(plate_shape[:split] + tail)

        if self.size != self.subsample_size:
            current = 1. if msg['scale'] is None else msg['scale']
            msg['scale'] = current * self.size / self.subsample_size

    def postprocess_message(self, msg):
        """
        Subsample the value of 'subsample' and 'param' messages along this
        plate's dim when an `event_dim` is provided.

        :raises ValueError: if the value's size along the plate dim is neither
            1 nor the full plate size.
        """
        if msg["type"] not in ("subsample", "param") or self.dim is None:
            return
        event_dim = msg["kwargs"].get("event_dim")
        if event_dim is None:
            return
        assert event_dim >= 0

        # Position of the plate dim once the event dims are accounted for.
        axis = self.dim - event_dim
        shape = jnp.shape(msg["value"])
        if len(shape) < -axis or shape[axis] == 1:
            # The value has no (non-trivial) extent along this plate.
            return

        if shape[axis] != self.size:
            if msg["type"] == "param":
                statement = "numpyro.param({}, ..., event_dim={})".format(msg["name"], event_dim)
            else:
                statement = "numpyro.subsample(..., event_dim={})".format(event_dim)
            raise ValueError(
                "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}"
                .format(self.name, self.size, self.dim, statement, shape))

        if self.subsample_size < self.size:
            msg["value"] = jnp.take(msg["value"], self._indices, axis)
@contextmanager
def plate_stack(prefix, sizes, rightmost_dim=-1):
    """
    Create a contiguous stack of :class:`plate` s with dimensions::

        rightmost_dim - len(sizes) + 1, ..., rightmost_dim

    The last entry of `sizes` is assigned to `rightmost_dim`, and plates are
    named ``"{prefix}_{i}"`` where ``i`` counts from the rightmost plate.

    :param str prefix: Name prefix for plates.
    :param iterable sizes: An iterable of plate sizes.
    :param int rightmost_dim: The rightmost dim, counting from the right.
    """
    assert rightmost_dim < 0
    # `reversed` only accepts sequences; materialize `sizes` so that any
    # iterable (e.g. a generator) works as documented.
    sizes = tuple(sizes)
    with ExitStack() as stack:
        for i, size in enumerate(reversed(sizes)):
            plate_i = plate("{}_{}".format(prefix, i), size, dim=rightmost_dim - i)
            stack.enter_context(plate_i)
        yield
def factor(name, log_factor):
    """
    Add an arbitrary log probability factor to a probabilistic model. This is
    implemented as an observed `sample` site whose distribution is a `Unit`
    distribution built from `log_factor`.

    :param str name: Name of the trivial sample.
    :param numpy.ndarray log_factor: A possibly batched log probability factor.
    """
    unit_dist = numpyro.distributions.distribution.Unit(log_factor)
    # Observe the distribution's own sample, drawn without a random key.
    sample(name, unit_dist, obs=unit_dist.sample(None))
def subsample(data, event_dim):
    """
    EXPERIMENTAL Subsampling statement that subsamples data according to the
    enclosing :class:`~numpyro.primitives.plate` s.

    It is typically applied to arguments of ``model()`` when subsampling is
    performed automatically by :class:`~numpyro.primitives.plate` s through
    the ``subsample_size`` kwarg. For example the following are equivalent::

        # Version 1. using indexing
        def model(data):
            with numpyro.plate("data", len(data), subsample_size=10, dim=-data.ndim) as ind:
                data = data[ind]
                # ...

        # Version 2. using numpyro.subsample()
        def model(data):
            with numpyro.plate("data", len(data), subsample_size=10, dim=-data.ndim):
                data = numpyro.subsample(data, event_dim=0)
                # ...

    :param numpy.ndarray data: A tensor of batched data.
    :param int event_dim: The event dimension of the data tensor. Dimensions to
        the left are considered batch dimensions.
    :returns: A subsampled version of ``data``
    :rtype: ~numpy.ndarray
    """
    if _PYRO_STACK:
        assert isinstance(event_dim, int) and event_dim >= 0
        # Let the active handlers (the enclosing plates) subsample the value.
        site = {
            'type': 'subsample',
            'value': data,
            'kwargs': {'event_dim': event_dim},
        }
        return apply_stack(site)['value']
    # No active handler: the data is returned unchanged.
    return data