# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from collections import OrderedDict, namedtuple
from contextlib import ExitStack # python 3
from enum import Enum
from jax import lax
import jax.numpy as jnp
import funsor
from numpyro.handlers import infer_config
from numpyro.handlers import trace as OrigTraceMessenger
from numpyro.primitives import Messenger, apply_stack
from numpyro.primitives import plate as OrigPlateMessenger
funsor.set_backend("jax")
__all__ = [
"enum",
"infer_config",
"markov",
"plate",
"to_data",
"to_funsor",
"trace",
]
##################################
# DimStack to store global state
##################################
# name_to_dim : dict, dim_to_name : dict, parents : tuple, iter_parents : tuple
class StackFrame(namedtuple('StackFrame', [
'name_to_dim',
'dim_to_name',
'parents',
'iter_parents',
'keep',
])):
def read(self, name, dim):
found_name = self.dim_to_name.get(dim, name)
found_dim = self.name_to_dim.get(name, dim)
found = name in self.name_to_dim or dim in self.dim_to_name
return found_name, found_dim, found
def write(self, name, dim):
assert name is not None and dim is not None
self.dim_to_name[dim] = name
self.name_to_dim[name] = dim
def free(self, name, dim):
self.dim_to_name.pop(dim, None)
self.name_to_dim.pop(name, None)
return name, dim
class DimType(Enum):
"""Enumerates the possible types of dimensions to allocate"""
LOCAL = 0
GLOBAL = 1
VISIBLE = 2
DimRequest = namedtuple('DimRequest', ['dim', 'dim_type'])
DimRequest.__new__.__defaults__ = (None, DimType.LOCAL)
NameRequest = namedtuple('NameRequest', ['name', 'dim_type'])
NameRequest.__new__.__defaults__ = (None, DimType.LOCAL)
class DimStack:
"""
Single piece of global state to keep track of the mapping between names and dimensions.
Replaces the plate DimAllocator, the enum EnumAllocator, the stack in MarkovMessenger,
_param_dims and _value_dims in EnumMessenger, and dim_to_symbol in msg['infer']
"""
def __init__(self):
self._stack = [StackFrame(
name_to_dim=OrderedDict(), dim_to_name=OrderedDict(),
parents=(), iter_parents=(), keep=False,
)]
self._first_available_dim = -1
self.outermost = None
MAX_DIM = -25
def set_first_available_dim(self, dim):
assert dim is None or (self.MAX_DIM < dim < 0)
old_dim, self._first_available_dim = self._first_available_dim, dim
return old_dim
def push(self, frame):
self._stack.append(frame)
def pop(self):
assert len(self._stack) > 1, "cannot pop the global frame"
return self._stack.pop()
@property
def current_frame(self):
return self._stack[-1]
@property
def global_frame(self):
return self._stack[0]
def _gendim(self, name_request, dim_request):
assert isinstance(name_request, NameRequest) and isinstance(dim_request, DimRequest)
dim_type = dim_request.dim_type
if name_request.name is None:
fresh_name = f"_pyro_dim_{-dim_request.dim}"
else:
fresh_name = name_request.name
conflict_frames = (self.current_frame, self.global_frame) + \
self.current_frame.parents + self.current_frame.iter_parents
if dim_request.dim is None:
fresh_dim = self._first_available_dim if dim_type != DimType.VISIBLE else -1
fresh_dim = -1 if fresh_dim is None else fresh_dim
while any(fresh_dim in p.dim_to_name for p in conflict_frames):
fresh_dim -= 1
else:
fresh_dim = dim_request.dim
if fresh_dim < self.MAX_DIM or \
any(fresh_dim in p.dim_to_name for p in conflict_frames) or \
(dim_type == DimType.VISIBLE and fresh_dim <= self._first_available_dim):
raise ValueError(f"Ran out of free dims during allocation for {fresh_name}")
return fresh_name, fresh_dim
def request(self, name, dim):
assert isinstance(name, NameRequest) ^ isinstance(dim, DimRequest)
if isinstance(dim, DimRequest):
dim, dim_type = dim.dim, dim.dim_type
elif isinstance(name, NameRequest):
name, dim_type = name.name, name.dim_type
read_frames = (self.global_frame,) if dim_type != DimType.LOCAL else \
(self.current_frame,) + self.current_frame.parents + self.current_frame.iter_parents + (self.global_frame,)
# read dimension
for frame in read_frames:
name, dim, found = frame.read(name, dim)
if found:
break
# generate fresh name or dimension
if not found:
name, dim = self._gendim(NameRequest(name, dim_type), DimRequest(dim, dim_type))
write_frames = (self.global_frame,) if dim_type != DimType.LOCAL else \
(self.current_frame,) + (self.current_frame.parents if self.current_frame.keep else ())
# store the fresh dimension
for frame in write_frames:
frame.write(name, dim)
return name, dim
_DIM_STACK = DimStack() # only one global instance
#################################################
# Messengers that implement guts of enumeration
#################################################
class ReentrantMessenger(Messenger):
def __init__(self, fn=None):
self._ref_count = 0
super().__init__(fn)
# def __call__(self, fn):
# return functools.wraps(fn)(super().__call__(fn))
def __enter__(self):
self._ref_count += 1
if self._ref_count == 1:
super().__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
self._ref_count -= 1
if self._ref_count == 0:
super().__exit__(exc_type, exc_value, traceback)
class DimStackCleanupMessenger(ReentrantMessenger):
def __init__(self, fn=None):
self._saved_dims = ()
return super().__init__(fn)
def __enter__(self):
if self._ref_count == 0 and _DIM_STACK.outermost is None:
_DIM_STACK.outermost = self
for name, dim in self._saved_dims:
_DIM_STACK.global_frame.write(name, dim)
self._saved_dims = ()
return super().__enter__()
def __exit__(self, *args, **kwargs):
if self._ref_count == 1 and _DIM_STACK.outermost is self:
_DIM_STACK.outermost = None
for name, dim in reversed(tuple(_DIM_STACK.global_frame.name_to_dim.items())):
self._saved_dims += (_DIM_STACK.global_frame.free(name, dim),)
return super().__exit__(*args, **kwargs)
class NamedMessenger(DimStackCleanupMessenger):
def process_message(self, msg):
if msg["type"] == "to_funsor":
self._pyro_to_funsor(msg)
elif msg["type"] == "to_data":
self._pyro_to_data(msg)
@staticmethod
def _get_name_to_dim(batch_names, name_to_dim=None, dim_type=DimType.LOCAL):
name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim.copy()
# interpret all names/dims as requests since we only run this function once
for name in batch_names:
dim = name_to_dim.get(name, None)
name_to_dim[name] = dim if isinstance(dim, DimRequest) else DimRequest(dim, dim_type)
# read dimensions and allocate fresh dimensions as necessary
for name, dim_request in name_to_dim.items():
name_to_dim[name] = _DIM_STACK.request(name, dim_request)[1]
return name_to_dim
@classmethod # only depends on the global _DIM_STACK state, not self
def _pyro_to_data(cls, msg):
funsor_value, = msg["args"]
name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict())
dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)
batch_names = tuple(funsor_value.inputs.keys())
name_to_dim.update(cls._get_name_to_dim(batch_names, name_to_dim=name_to_dim, dim_type=dim_type))
msg["stop"] = True # only need to run this once per to_data call
@staticmethod
def _get_dim_to_name(batch_shape, dim_to_name=None, dim_type=DimType.LOCAL):
dim_to_name = OrderedDict() if dim_to_name is None else dim_to_name.copy()
batch_dim = len(batch_shape)
# interpret all names/dims as requests since we only run this function once
for dim in range(-batch_dim, 0):
name = dim_to_name.get(dim, None)
# the time dimension on the left sometimes necessitates empty dimensions appearing
# before they have been assigned a name
if batch_shape[dim] == 1 and name is None:
continue
dim_to_name[dim] = name if isinstance(name, NameRequest) else NameRequest(name, dim_type)
for dim, name_request in dim_to_name.items():
dim_to_name[dim] = _DIM_STACK.request(name_request, dim)[0]
return dim_to_name
@classmethod # only depends on the global _DIM_STACK state, not self
def _pyro_to_funsor(cls, msg):
if len(msg["args"]) == 2:
raw_value, output = msg["args"]
else:
raw_value = msg["args"][0]
output = msg["kwargs"].setdefault("output", None)
dim_to_name = msg["kwargs"].setdefault("dim_to_name", OrderedDict())
dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)
event_dim = len(output.shape) if output else 0
try:
batch_shape = raw_value.batch_shape # TODO make make this more robust
except AttributeError:
batch_shape = raw_value.shape[:len(raw_value.shape) - event_dim]
dim_to_name.update(cls._get_dim_to_name(batch_shape, dim_to_name=dim_to_name, dim_type=dim_type))
msg["stop"] = True # only need to run this once per to_funsor call
class LocalNamedMessenger(NamedMessenger):
"""
Handler for converting to/from funsors consistent with Pyro's positional batch dimensions.
:param int history: The number of previous contexts visible from the
current context. Defaults to 1. If zero, this is similar to
:class:`pyro.plate`.
:param bool keep: If true, frames are replayable. This is important
when branching: if ``keep=True``, neighboring branches at the same
level can depend on each other; if ``keep=False``, neighboring branches
are independent (conditioned on their shared ancestors).
"""
def __init__(self, fn=None, history=1, keep=False):
self.history = history
self.keep = keep
self._iterable = None
self._saved_frames = []
self._iter_parents = ()
super().__init__(fn)
def generator(self, iterable):
self._iterable = iterable
return self
def _get_iter_parents(self, frame):
iter_parents = [frame]
frontier = (frame,)
while frontier:
frontier = sum([p.iter_parents for p in frontier], ())
iter_parents += frontier
return tuple(iter_parents)
def __iter__(self):
assert self._iterable is not None
self._iter_parents = self._get_iter_parents(_DIM_STACK.current_frame)
with ExitStack() as stack:
for value in self._iterable:
stack.enter_context(self)
yield value
def __enter__(self):
if self.keep and self._saved_frames:
saved_frame = self._saved_frames.pop()
name_to_dim, dim_to_name = saved_frame.name_to_dim, saved_frame.dim_to_name
else:
name_to_dim, dim_to_name = OrderedDict(), OrderedDict()
frame = StackFrame(
name_to_dim=name_to_dim, dim_to_name=dim_to_name,
parents=tuple(reversed(_DIM_STACK._stack[len(_DIM_STACK._stack) - self.history:])),
iter_parents=tuple(self._iter_parents),
keep=self.keep
)
_DIM_STACK.push(frame)
return super().__enter__()
def __exit__(self, *args, **kwargs):
if self.keep:
# don't keep around references to other frames
old_frame = _DIM_STACK.pop()
saved_frame = StackFrame(
name_to_dim=old_frame.name_to_dim, dim_to_name=old_frame.dim_to_name,
parents=(), iter_parents=(), keep=self.keep
)
self._saved_frames.append(saved_frame)
else:
_DIM_STACK.pop()
return super().__exit__(*args, **kwargs)
class GlobalNamedMessenger(NamedMessenger):
def __init__(self, fn=None):
self._saved_globals = ()
super().__init__(fn)
def __enter__(self):
if self._ref_count == 0:
for name, dim in self._saved_globals:
_DIM_STACK.global_frame.write(name, dim)
self._saved_globals = ()
return super().__enter__()
def __exit__(self, *args, **kwargs):
if self._ref_count == 1:
for name, dim in self._saved_globals:
_DIM_STACK.global_frame.free(name, dim)
return super().__exit__(*args, **kwargs)
def postprocess_message(self, msg):
if msg["type"] == "to_funsor":
self._pyro_post_to_funsor(msg)
elif msg["type"] == "to_data":
self._pyro_post_to_data(msg)
def _pyro_post_to_funsor(self, msg):
if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
for name in msg["value"].inputs:
self._saved_globals += ((name, _DIM_STACK.global_frame.name_to_dim[name]),)
def _pyro_post_to_data(self, msg):
if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
for name in msg["args"][0].inputs:
self._saved_globals += ((name, _DIM_STACK.global_frame.name_to_dim[name]),)
class BaseEnumMessenger(NamedMessenger):
"""
Handles first_available_dim management, enum effects should inherit from this
"""
def __init__(self, fn=None, first_available_dim=None):
assert first_available_dim is None or first_available_dim < 0, first_available_dim
self.first_available_dim = first_available_dim
super().__init__(fn)
def __enter__(self):
if self._ref_count == 0 and self.first_available_dim is not None:
self._prev_first_dim = _DIM_STACK.set_first_available_dim(self.first_available_dim)
return super().__enter__()
def __exit__(self, *args, **kwargs):
if self._ref_count == 1 and self.first_available_dim is not None:
_DIM_STACK.set_first_available_dim(self._prev_first_dim)
return super().__exit__(*args, **kwargs)
##########################################
# User-facing handler implementations
##########################################
[docs]class plate(GlobalNamedMessenger):
"""
An alternative implementation of :class:`numpyro.primitives.plate` primitive. Note
that only this version is compatible with enumeration.
There is also a context manager
:func:`~numpyro.contrib.funsor.infer_util.plate_to_enum_plate`
which converts `numpyro.plate` statements to this version.
:param str name: Name of the plate.
:param int size: Size of the plate.
:param int subsample_size: Optional argument denoting the size of the mini-batch.
This can be used to apply a scaling factor by inference algorithms. e.g.
when computing ELBO using a mini-batch.
:param int dim: Optional argument to specify which dimension in the tensor
is used as the plate dim. If `None` (default), the leftmost available dim
is allocated.
"""
def __init__(self, name, size, subsample_size=None, dim=None):
self.name = name
self.size = size
if dim is not None and dim >= 0:
raise ValueError('dim arg must be negative.')
self.dim, indices = OrigPlateMessenger._subsample(self.name, self.size, subsample_size, dim)
self.subsample_size = indices.shape[0]
self._indices = funsor.Tensor(
indices,
OrderedDict([(self.name, funsor.bint(self.subsample_size))]),
self.subsample_size
)
super(plate, self).__init__(None)
def __enter__(self):
super().__enter__() # do this first to take care of globals recycling
name_to_dim = OrderedDict([(self.name, self.dim)]) if self.dim is not None else OrderedDict()
indices = to_data(self._indices, name_to_dim=name_to_dim, dim_type=DimType.VISIBLE)
# extract the dimension allocated by to_data to match plate's current behavior
self.dim, self.indices = -len(indices.shape), indices.squeeze()
return self.indices
@staticmethod
def _get_batch_shape(cond_indep_stack):
n_dims = max(-f.dim for f in cond_indep_stack)
batch_shape = [1] * n_dims
for f in cond_indep_stack:
batch_shape[f.dim] = f.size
return tuple(batch_shape)
[docs] def process_message(self, msg):
if msg["type"] in ["to_funsor", "to_data"]:
return super().process_message(msg)
return OrigPlateMessenger.process_message(self, msg)
[docs] def postprocess_message(self, msg):
if msg["type"] in ["to_funsor", "to_data"]:
return super().postprocess_message(msg)
# NB: copied literally from original plate messenger, with self._indices is replaced
# by self.indices
if msg["type"] in ("subsample", "param") and self.dim is not None:
event_dim = msg["kwargs"].get("event_dim")
if event_dim is not None:
assert event_dim >= 0
dim = self.dim - event_dim
shape = jnp.shape(msg["value"])
if len(shape) >= -dim and shape[dim] != 1:
if shape[dim] != self.size:
if msg["type"] == "param":
statement = "numpyro.param({}, ..., event_dim={})".format(msg["name"], event_dim)
else:
statement = "numpyro.subsample(..., event_dim={})".format(event_dim)
raise ValueError(
"Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}"
.format(self.name, self.size, self.dim, statement, shape))
if self.subsample_size < self.size:
value = msg["value"]
new_value = jnp.take(value, self.indices, dim)
msg["value"] = new_value
[docs]class enum(BaseEnumMessenger):
"""
Enumerates in parallel over discrete sample sites marked
``infer={"enumerate": "parallel"}``.
:param callable fn: Python callable with NumPyro primitives.
:param int first_available_dim: The first tensor dimension (counting
from the right) that is available for parallel enumeration. This
dimension and all dimensions left may be used internally by Pyro.
This should be a negative integer or None.
"""
[docs] def process_message(self, msg):
if msg["type"] != "sample" or \
msg.get("done", False) or msg["is_observed"] or msg["infer"].get("expand", False) or \
msg["infer"].get("enumerate") != "parallel" or (not msg["fn"].has_enumerate_support):
if msg["type"] == "control_flow":
msg["kwargs"]["enum"] = True
msg["kwargs"]["first_available_dim"] = self.first_available_dim
return super().process_message(msg)
if msg["infer"].get("num_samples", None) is not None:
raise NotImplementedError("TODO implement multiple sampling")
if msg["infer"].get("expand", False):
raise NotImplementedError("expand=True not implemented")
size = msg["fn"].enumerate_support(expand=False).shape[0]
raw_value = jnp.arange(0, size)
funsor_value = funsor.Tensor(
raw_value,
OrderedDict([(msg["name"], funsor.bint(size))]),
size
)
msg["value"] = to_data(funsor_value)
msg["done"] = True
[docs]class trace(OrigTraceMessenger):
"""
This version of :class:`~numpyro.handlers.trace` handler records
information necessary to do packing after execution.
Each sample site is annotated with a "dim_to_name" dictionary,
which can be passed directly to :func:`to_funsor`.
"""
[docs] def postprocess_message(self, msg):
if msg["type"] == "sample":
total_batch_shape = lax.broadcast_shapes(
tuple(msg["fn"].batch_shape),
jnp.shape(msg["value"])[:jnp.ndim(msg["value"]) - msg["fn"].event_dim]
)
msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(total_batch_shape)
if msg["type"] in ("sample", "param"):
super().postprocess_message(msg)
[docs]def markov(fn=None, history=1, keep=False):
"""
Markov dependency declaration.
This is a statistical equivalent of a memory management arena.
:param callable fn: Python callable with NumPyro primitives.
:param int history: The number of previous contexts visible from the
current context. Defaults to 1. If zero, this is similar to
:class:`numpyro.primitives.plate`.
:param bool keep: If true, frames are replayable. This is important
when branching: if ``keep=True``, neighboring branches at the same
level can depend on each other; if ``keep=False``, neighboring branches
are independent (conditioned on their shared ancestors).
"""
if fn is not None and not callable(fn): # Used as a generator
return LocalNamedMessenger(fn=None, history=history, keep=keep).generator(iterable=fn)
return LocalNamedMessenger(fn, history=history, keep=keep)
####################
# New primitives
####################
[docs]def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL):
"""
A primitive to convert a Python object to a :class:`~funsor.terms.Funsor`.
:param x: An object.
:param funsor.domains.Domain output: An optional output hint to uniquely
convert a data to a Funsor (e.g. when `x` is a string).
:param OrderedDict dim_to_name: An optional mapping from negative
batch dimensions to name strings.
:param int dim_type: Either 0, 1, or 2. This optional argument indicates
a dimension should be treated as 'local', 'global', or 'visible',
which can be used to interact with the global :class:`DimStack`.
:return: A Funsor equivalent to `x`.
:rtype: funsor.terms.Funsor
"""
dim_to_name = OrderedDict() if dim_to_name is None else dim_to_name
initial_msg = {
'type': 'to_funsor',
'fn': lambda x, output, dim_to_name, dim_type: funsor.to_funsor(x, output=output, dim_to_name=dim_to_name),
'args': (x,),
'kwargs': {"output": output, "dim_to_name": dim_to_name, "dim_type": dim_type},
'value': None,
'mask': None,
}
msg = apply_stack(initial_msg)
return msg['value']
[docs]def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL):
"""
A primitive to extract a python object from a :class:`~funsor.terms.Funsor`.
:param ~funsor.terms.Funsor x: A funsor object
:param OrderedDict name_to_dim: An optional inputs hint which maps
dimension names from `x` to dimension positions of the returned value.
:param int dim_type: Either 0, 1, or 2. This optional argument indicates
a dimension should be treated as 'local', 'global', or 'visible',
which can be used to interact with the global :class:`DimStack`.
:return: A non-funsor equivalent to `x`.
"""
name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim
initial_msg = {
'type': 'to_data',
'fn': lambda x, name_to_dim, dim_type: funsor.to_data(x, name_to_dim=name_to_dim),
'args': (x,),
'kwargs': {"name_to_dim": name_to_dim, "dim_type": dim_type},
'value': None,
'mask': None,
}
msg = apply_stack(initial_msg)
return msg['value']