# Source code for numpyro.util

import random
from collections import namedtuple
from contextlib import contextmanager

import numpy as onp
import tqdm

import jax.numpy as np
from jax import jit, lax, ops, vmap
from jax.flatten_util import ravel_pytree
from jax.tree_util import register_pytree_node

_DATA_TYPES = {}
_DISABLE_CONTROL_FLOW_PRIM = False


def set_rng_seed(rng_seed):
    """
    Seed the global random number generators of both Python's :mod:`random`
    module and NumPy with the same ``rng_seed``.

    :param rng_seed: seed value forwarded unchanged to both seeding functions.
    """
    for seed_fn in (random.seed, onp.random.seed):
        seed_fn(rng_seed)


# let JAX recognize _TreeInfo structure
# ref: https://github.com/google/jax/issues/446
# TODO: remove this when namedtuple is supported in JAX
def register_pytree(cls):
    """
    Register ``cls`` as a JAX pytree node.

    Instances are flattened with ``tuple(xs)`` (no auxiliary data) and rebuilt
    with ``cls(*xs)``, so ``cls`` must be iterable and its constructor must
    accept the flattened children positionally (e.g. a namedtuple).

    A ``_registered`` flag is stored on ``cls`` itself, so calling this more
    than once for the same class registers it only once.

    :param cls: the class to register.
    """
    # Read the flag from the class's own namespace instead of via getattr:
    # getattr would also see a ``_registered = True`` inherited from an
    # already-registered base class, so a subclass would silently never be
    # registered, even though JAX pytree registrations apply to the exact
    # type only.
    if not vars(cls).get('_registered', False):
        register_pytree_node(
            cls,
            lambda xs: (tuple(xs), None),
            lambda _, xs: cls(*xs),
        )
    cls._registered = True


def laxtuple(name, fields):
    """
    Return a namedtuple class called ``name`` with the given ``fields`` that is
    registered as a JAX pytree and exposes ``update`` as an alias of
    ``_replace``.

    Created classes are memoized in ``_DATA_TYPES`` under the key
    ``(name,) + tuple(fields)``, so repeated calls with the same arguments
    return the very same class object.

    :param name: name of the namedtuple class.
    :param fields: field names, as accepted by :func:`collections.namedtuple`.
    :return: the (possibly cached) namedtuple class.
    """
    cache_key = (name,) + tuple(fields)
    try:
        return _DATA_TYPES[cache_key]
    except KeyError:
        pass
    new_cls = namedtuple(name, fields)
    register_pytree(new_cls)
    new_cls.update = new_cls._replace
    _DATA_TYPES[cache_key] = new_cls
    return new_cls


@contextmanager
def optional(condition, context_manager):
    """
    Enter ``context_manager`` around the managed block only when ``condition``
    is truthy; otherwise run the block with no wrapping at all.

    :param condition: whether to activate ``context_manager``.
    :param context_manager: the context manager to enter conditionally.
    """
    if not condition:
        yield
        return
    with context_manager:
        yield


@contextmanager
def control_flow_prims_disabled():
    """
    Temporarily set the module-level ``_DISABLE_CONTROL_FLOW_PRIM`` flag to
    ``True``, so that ``cond``/``while_loop``/``fori_loop`` in this module run
    as plain Python instead of the ``lax`` primitives.

    The previous value of the flag is restored on exit, including when the
    managed block raises.
    """
    global _DISABLE_CONTROL_FLOW_PRIM
    previous_flag = _DISABLE_CONTROL_FLOW_PRIM
    _DISABLE_CONTROL_FLOW_PRIM = True
    try:
        yield
    finally:
        _DISABLE_CONTROL_FLOW_PRIM = previous_flag


def cond(pred, true_operand, true_fun, false_operand, false_fun):
    """
    Dispatch to ``true_fun(true_operand)`` or ``false_fun(false_operand)``
    depending on ``pred``.

    When the module flag ``_DISABLE_CONTROL_FLOW_PRIM`` is set, the branch is
    chosen with a plain Python ``if`` (so ``pred`` must be concrete); otherwise
    the arguments are forwarded unchanged to :func:`jax.lax.cond`.
    """
    if not _DISABLE_CONTROL_FLOW_PRIM:
        return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
    return true_fun(true_operand) if pred else false_fun(false_operand)


def while_loop(cond_fun, body_fun, init_val):
    """
    Repeatedly apply ``body_fun`` to a carried value while ``cond_fun`` holds.

    When the module flag ``_DISABLE_CONTROL_FLOW_PRIM`` is set, this runs as an
    ordinary Python ``while`` loop; otherwise the arguments are forwarded to
    :func:`jax.lax.while_loop`.

    :return: the final carried value.
    """
    if not _DISABLE_CONTROL_FLOW_PRIM:
        return lax.while_loop(cond_fun, body_fun, init_val)
    state = init_val
    while cond_fun(state):
        state = body_fun(state)
    return state


def fori_loop(lower, upper, body_fun, init_val):
    """
    Apply ``body_fun(i, val)`` for ``i`` in ``[lower, upper)``, threading the
    carried value through each call.

    When the module flag ``_DISABLE_CONTROL_FLOW_PRIM`` is set, the bounds are
    converted with ``int`` and a Python ``for`` loop is used; otherwise the
    arguments are forwarded to :func:`jax.lax.fori_loop`.

    :return: the final carried value.
    """
    if not _DISABLE_CONTROL_FLOW_PRIM:
        return lax.fori_loop(lower, upper, body_fun, init_val)
    carry = init_val
    start, stop = int(lower), int(upper)
    for idx in range(start, stop):
        carry = body_fun(idx, carry)
    return carry


def identity(x):
    """Return ``x`` unchanged (the default ``transform`` of ``fori_collect``)."""
    unchanged = x
    return unchanged


def fori_collect(n, body_fun, init_val, transform=identity, progbar=True, **progbar_opts):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the
    additional effect of collecting values from the loop body. In addition,
    this allows for post-processing of these samples via `transform`, and
    progress bar updates. Note that, in some cases, `progbar=False` can be
    faster, when collecting a lot of samples. Refer to example usage in
    :func:`~numpyro.mcmc.hmc`.

    :param int n: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of `np.ndarray` and
        returns a collection with the same shape and `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can be
        any Python collection type containing `np.ndarray` objects.
    :param transform: a callable applied to `init_val` and to every value
        returned by `body_fun` before it is collected. Defaults to the identity.
    :param progbar: whether to post progress bar updates.
    :param `**progbar_opts`: optional additional progress bar arguments. A
        `diagnostics_fn` can be supplied which when passed the current value
        from `body_fun` returns a string that is used to update the progress
        bar postfix. Also a `progbar_desc` keyword argument can be supplied
        which is used to label the progress bar. Any other keys are currently
        ignored.
    :return: collection with the same structure as `transform(init_val)` with
        values collected along the leading axis of `np.ndarray` objects.

    NOTE: with `progbar=True`, `n` must be positive; collecting zero samples
    makes `onp.stack` fail on an empty list.
    """
    # Flatten the transformed initial value once. This fixes the flat
    # size/dtype of every collected sample and gives the inverse (unravel)
    # function used to rebuild the structured output at the end.
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
    ravel_fn = lambda x: ravel_pytree(transform(x))[0]  # noqa: E731

    if not progbar:
        collection = np.zeros((n,) + init_val_flat.shape, dtype=init_val_flat.dtype)

        def _body_fn(i, vals):
            val, collection = vals
            val = body_fun(val)
            collection = ops.index_update(collection, i, ravel_fn(val))
            return val, collection

        # The loop body (positional argument 2 of fori_loop) is a Python
        # callable, so it must be marked static for jit.
        _, collection = jit(lax.fori_loop, static_argnums=(2,))(0, n, _body_fn, (init_val, collection))
    else:
        diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
        progbar_desc = progbar_opts.pop('progbar_desc', '')
        # Build the jitted flattener once, rather than re-wrapping ravel_fn
        # with jit on every iteration of the sampling loop.
        jitted_ravel_fn = jit(ravel_fn)
        collection = []
        val = init_val
        with tqdm.trange(n, desc=progbar_desc) as t:
            for _ in t:
                val = body_fun(val)
                collection.append(jitted_ravel_fn(val))
                if diagnostics_fn:
                    # TODO: set refresh=True when its performance issue is resolved
                    t.set_postfix_str(diagnostics_fn(val), refresh=False)
        # XXX: jax.numpy.stack/concatenate is currently so slow
        collection = onp.stack(collection)
    return vmap(unravel_fn)(collection)
def copy_docs_from(source_class, full_text=False):
    """
    Build a class decorator that fills in missing docstrings on the public
    attributes of the decorated class, using the same-named attributes of
    ``source_class``.

    Only attributes whose name does not start with ``_`` are considered, and
    only when the source attribute has a non-empty docstring while the
    destination attribute has none. Properties are replaced by an equivalent
    property carrying the new docstring, because a property's ``__doc__``
    cannot be assigned. The class-level docstring itself is not copied.

    :param source_class: class whose attribute docstrings serve as the source.
    :param bool full_text: if True, copy the source docstring verbatim.
        Otherwise, unless the source docstring already starts with ``'See '``,
        write a short ``See :meth:`module.Class.name``` cross-reference.
    :return: a decorator that updates the decorated class in place and
        returns it.
    """
    def _doc_for(name, source_doc):
        # Choose between the verbatim text and a cross-reference.
        if full_text or source_doc.startswith('See '):
            return source_doc
        return 'See :meth:`{}.{}.{}`'.format(
            source_class.__module__, source_class.__name__, name)

    def decorator(destin_class):
        for name in dir(destin_class):
            if name.startswith('_'):
                continue
            target = getattr(destin_class, name)
            # Unwrap bound methods (e.g. classmethods) to reach the function.
            target = getattr(target, '__func__', target)
            source_doc = getattr(getattr(source_class, name, None), '__doc__', None)
            if not source_doc or getattr(target, '__doc__', None):
                continue
            new_doc = _doc_for(name, source_doc)
            if isinstance(target, property):
                # A property's __doc__ is read-only, so rebuild the property
                # with the same accessors and the new docstring.
                setattr(destin_class, name,
                        property(target.fget, target.fset, target.fdel, new_doc))
            else:
                target.__doc__ = new_doc
        return destin_class

    return decorator