Effect Handlers

This module provides a small set of effect handlers in NumPyro that are modeled after Pyro’s poutine module. For a tutorial on effect handlers more generally, see Poutine: A Guide to Programming with Effect Handlers in Pyro. These simple effect handlers can be composed, and new ones can be added, to implement custom inference utilities and algorithms.

Example

As an example, we use the seed, trace and substitute handlers to define the log_likelihood function below. We first create a logistic regression model and sample from the posterior distribution over the regression parameters using MCMC(). The log_likelihood function uses effect handlers to run the model with sample sites substituted by values from one posterior draw, and computes the log density of each data point under that draw. The log_predictive_density function evaluates log_likelihood for every draw from the joint posterior and aggregates the results over all the data points. It uses JAX’s auto-vectorization transform vmap, so we do not need to write an explicit loop over the posterior draws.

>>> import jax.numpy as jnp
>>> from jax import random, vmap
>>> from jax.scipy.special import logsumexp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>> from numpyro.infer import MCMC, NUTS

>>> N, D = 3000, 3
>>> def logistic_regression(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(D), jnp.ones(D)))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     logits = jnp.sum(coefs * data + intercept, axis=-1)
...     return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

>>> data = random.normal(random.PRNGKey(0), (N, D))
>>> true_coefs = jnp.arange(1., D + 1.)
>>> logits = jnp.sum(true_coefs * data, axis=-1)
>>> labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

>>> num_warmup, num_samples = 1000, 1000
>>> mcmc = MCMC(NUTS(model=logistic_regression), num_warmup, num_samples)
>>> mcmc.run(random.PRNGKey(2), data, labels)  
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]
>>> mcmc.print_summary()  


                   mean         sd       5.5%      94.5%      n_eff       Rhat
    coefs[0]       0.96       0.07       0.85       1.07     455.35       1.01
    coefs[1]       2.05       0.09       1.91       2.20     332.00       1.01
    coefs[2]       3.18       0.13       2.96       3.37     320.27       1.00
   intercept      -0.03       0.02      -0.06       0.00     402.53       1.00

>>> def log_likelihood(rng_key, params, model, *args, **kwargs):
...     model = handlers.substitute(handlers.seed(model, rng_key), params)
...     model_trace = handlers.trace(model).get_trace(*args, **kwargs)
...     obs_node = model_trace['obs']
...     return obs_node['fn'].log_prob(obs_node['value'])

>>> def log_predictive_density(rng_key, params, model, *args, **kwargs):
...     n = list(params.values())[0].shape[0]
...     log_lk_fn = vmap(lambda rng_key, params: log_likelihood(rng_key, params, model, *args, **kwargs))
...     log_lk_vals = log_lk_fn(random.split(rng_key, n), params)
...     return jnp.sum(logsumexp(log_lk_vals, 0) - jnp.log(n))

>>> print(log_predictive_density(random.PRNGKey(2), mcmc.get_samples(),
...       logistic_regression, data, labels))  
-874.89813
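
For reference, the aggregation in log_predictive_density can be written out as follows. The notation is an assumption introduced here for illustration: n is the number of posterior draws (as in the code), N is the number of data points, and θ^(s) denotes the s-th draw of the parameters (coefs, intercept).

```latex
\sum_{i=1}^{N} \log\!\left( \frac{1}{n} \sum_{s=1}^{n} p\big(y_i \mid x_i, \theta^{(s)}\big) \right)
= \sum_{i=1}^{N} \left[ \operatorname{logsumexp}_{s=1,\dots,n} \log p\big(y_i \mid x_i, \theta^{(s)}\big) - \log n \right]
```

The inner logsumexp over s corresponds to logsumexp(log_lk_vals, 0), and the outer sum over i corresponds to jnp.sum.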

block

class block(fn=None, hide_fn=None, hide=None)

Bases: numpyro.primitives.Messenger

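Given a callable fn, return another callable that selectively hides primitive sites where hide_fn returns True from other effect handlers on the stack. If neither hide_fn nor hide is given, all sites are hidden, as block_all in the example below shows.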

Parameters:
  • fn – Python callable with NumPyro primitives.
  • hide_fn – function which when given a dictionary containing site-level metadata returns whether it should be blocked.
  • hide – list of site names to hide.

Example:

>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import block, seed, trace
>>> import numpyro.distributions as dist

>>> def model():
...     a = numpyro.sample('a', dist.Normal(0., 1.))
...     return numpyro.sample('b', dist.Normal(a, 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> block_all = block(model)
>>> block_a = block(model, lambda site: site['name'] == 'a')
>>> trace_block_all = trace(block_all).get_trace()
>>> assert not {'a', 'b'}.intersection(trace_block_all.keys())
>>> trace_block_a =  trace(block_a).get_trace()
>>> assert 'a' not in trace_block_a
>>> assert 'b' in trace_block_a

process_message(msg)

condition

class condition(fn=None, data=None, condition_fn=None, param_map=None)

Bases: numpyro.primitives.Messenger

Conditions unobserved sample sites to values from data or condition_fn. Similar to substitute except that it only affects sample sites and changes the is_observed property to True.

Parameters:
  • fn – Python callable with NumPyro primitives.
  • data (dict) – dictionary of numpy.ndarray values keyed by site names.
  • condition_fn – callable that takes in a site dict and returns a numpy array or None (in which case the handler has no side effect).

Example:

>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import condition, seed, substitute, trace
>>> import numpyro.distributions as dist

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> exec_trace = trace(condition(model, {'a': -1})).get_trace()
>>> assert exec_trace['a']['value'] == -1
>>> assert exec_trace['a']['is_observed']

process_message(msg)

lift

class lift(fn=None, prior=None)

Bases: numpyro.primitives.Messenger

Given a stochastic function with param calls and a prior distribution, create a stochastic function where all param calls are replaced by sampling from prior. Prior should be a distribution or a dict of names to distributions.

Consider the following NumPyro program:

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import lift
>>>
>>> def model(x):
...     s = numpyro.param("s", 0.5)
...     z = numpyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = lift(model, prior={"s": dist.Exponential(0.3)})

lift makes param statements behave like sample statements using the distributions in prior. In this example, site s will now behave as if it was replaced with s = numpyro.sample("s", dist.Exponential(0.3)).

Parameters:
  • fn – function whose parameters will be lifted to random values
  • prior – prior function in the form of a Distribution or a dict of Distributions

process_message(msg)

mask

class mask(fn=None, mask_array=True)

Bases: numpyro.primitives.Messenger

This messenger masks out some of the sample statements elementwise.

Parameters:
  • mask_array – a DeviceArray with bool dtype for elementwise masking of sample sites.

process_message(msg)

reparam

class reparam(fn=None, config=None)

Bases: numpyro.primitives.Messenger

Reparametrizes each affected sample site into one or more auxiliary sample sites followed by a deterministic transformation [1].

To specify reparameterizers, pass a config dict or callable to the constructor. See the numpyro.infer.reparam module for available reparameterizers.

Note that some reparameterizers can examine the *args, **kwargs inputs of functions they affect; these reparameterizers require using handlers.reparam as a decorator rather than as a context manager.

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019), “Automatic Reparameterisation of Probabilistic Programs”, https://arxiv.org/pdf/1906.03028.pdf

Parameters:
  • config (dict or callable) – Configuration, either a dict mapping site name to Reparam, or a function mapping site to Reparam or None.

process_message(msg)

replay

class replay(fn=None, guide_trace=None)

Bases: numpyro.primitives.Messenger

Given a callable fn and an execution trace guide_trace, return a callable which substitutes sample calls in fn with values from the corresponding site names in guide_trace.

Parameters:
  • fn – Python callable with NumPyro primitives.
  • guide_trace – an OrderedDict containing execution metadata.

Example

>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import replay, seed, trace

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> print(exec_trace['a']['value'])  
-0.20584235
>>> replayed_trace = trace(replay(model, exec_trace)).get_trace()
>>> print(replayed_trace['a']['value'])
-0.20584235
>>> assert replayed_trace['a']['value'] == exec_trace['a']['value']

process_message(msg)

scale

class scale(fn=None, scale=1.0)

Bases: numpyro.primitives.Messenger

This messenger rescales the log probability score.

This is typically used for data subsampling or for stratified sampling of data (e.g. in fraud detection where negatives vastly outnumber positives).

Parameters:
  • scale (float) – a positive scaling factor

process_message(msg)

scope

class scope(fn=None, prefix='')

Bases: numpyro.primitives.Messenger

This handler prepends a prefix followed by a / to the names of sample sites.

Example:

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import scope, seed, trace
>>>
>>> def model():
...     with scope(prefix="a"):
...         with scope(prefix="b"):
...             return numpyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/b/x" in trace(seed(model, 0)).get_trace()

Parameters:
  • fn – Python callable with NumPyro primitives.
  • prefix (str) – a string to prepend to sample names

process_message(msg)

seed

class seed(fn=None, rng_seed=None, rng=None)

Bases: numpyro.primitives.Messenger

JAX uses a functional pseudo random number generator that requires passing in a seed PRNGKey() to every stochastic function. The seed handler allows us to initially seed a stochastic function with a PRNGKey(). Every call to the sample() primitive inside the function results in a splitting of this initial seed so that we use a fresh seed for each subsequent call without having to explicitly pass in a PRNGKey to each sample call.

Parameters:
  • fn – Python callable with NumPyro primitives.
  • rng_seed (int, jnp.ndarray scalar, or jax.random.PRNGKey) – a random number generator seed.

Note

Unlike in Pyro, the numpyro.sample primitive cannot be used without wrapping it in a seed handler, since there is no global random state. Users therefore need to use seed either as a context manager when generating samples from distributions, or as a decorator for their model callable (see below).

Example:

>>> from jax import random
>>> import numpyro
>>> from numpyro import handlers
>>> import numpyro.distributions as dist

>>> # as context manager
>>> with handlers.seed(rng_seed=1):
...     x = numpyro.sample('x', dist.Normal(0., 1.))

>>> def model():
...     return numpyro.sample('y', dist.Normal(0., 1.))

>>> # as function decorator (/modifier)
>>> y = handlers.seed(model, rng_seed=1)()
>>> assert x == y

process_message(msg)

substitute

class substitute(fn=None, data=None, substitute_fn=None, param_map=None)

Bases: numpyro.primitives.Messenger

Given a callable fn and a dict data keyed by site names (alternatively, a callable substitute_fn), return a callable which substitutes all primitive calls in fn with values from data whose key matches the site name. If the site name is not present in data, there is no side effect.

If a substitute_fn is provided, then the value at the site is replaced by the value returned from the call to substitute_fn for the given site.

Parameters:
  • fn – Python callable with NumPyro primitives.
  • data (dict) – dictionary of numpy.ndarray values keyed by site names.
  • substitute_fn – callable that takes in a site dict and returns a numpy array or None (in which case the handler has no side effect).

Example:

>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import seed, substitute, trace
>>> import numpyro.distributions as dist

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> exec_trace = trace(substitute(model, {'a': -1})).get_trace()
>>> assert exec_trace['a']['value'] == -1

process_message(msg)

trace

class trace(fn=None)

Bases: numpyro.primitives.Messenger

Returns a handler that records the inputs and outputs at primitive calls inside fn.

Example

>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import seed, trace
>>> import pprint as pp

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> pp.pprint(exec_trace)  
OrderedDict([('a',
              {'args': (),
               'fn': <numpyro.distributions.continuous.Normal object at 0x7f9e689b1eb8>,
               'is_observed': False,
               'kwargs': {'rng_key': DeviceArray([0, 0], dtype=uint32)},
               'name': 'a',
               'type': 'sample',
               'value': DeviceArray(-0.20584235, dtype=float32)})])

postprocess_message(msg)

get_trace(*args, **kwargs)

Run the wrapped callable and return the recorded trace.

Parameters:
  • *args – arguments to the callable.
  • **kwargs – keyword arguments to the callable.
Returns:

OrderedDict containing the execution trace.