Optimizers

The optimizer classes defined here are lightweight wrappers around the corresponding optimizers from jax.experimental.optimizers, with an interface better suited to NumPyro's inference algorithms.

Adam

class Adam(*args, **kwargs)

Wrapper class for the JAX optimizer: adam()
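For reference, here is a NumPy sketch of a single Adam step as we understand JAX's adam() to compute it, assuming its default hyperparameters (b1=0.9, b2=0.999, eps=1e-8). Constructor arguments such as Adam(step_size=1e-3) are forwarded to adam(). The helper adam_step is hypothetical and not part of NumPyro.

```python
import numpy as np


def adam_step(i, x, g, m, v, step_size=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step (sketch). i is the 0-based iteration counter."""
    m = (1 - b1) * g + b1 * m            # first-moment (mean) estimate
    v = (1 - b2) * g ** 2 + b2 * v       # second-moment (uncentered) estimate
    m_hat = m / (1 - b1 ** (i + 1))      # bias corrections for the zero initialization
    v_hat = v / (1 - b2 ** (i + 1))
    x = x - step_size * m_hat / (np.sqrt(v_hat) + eps)
    return x, m, v


x, m, v = np.array([1.0, -1.0]), np.zeros(2), np.zeros(2)
x, m, v = adam_step(0, x, np.array([0.5, -4.0]), m, v, step_size=0.1)
# x is now approximately [0.9, -0.9]
```

Because of the bias correction, the first step moves each coordinate by roughly step_size against the sign of its gradient, regardless of the gradient's magnitude.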

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the value of the objective function or any of its gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is computed from the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead re-evaluate the function multiple times to find optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with the parameters to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradient information for parameters.
  • state – current optimizer state.
Returns:

new optimizer state after the update.

Adagrad

class Adagrad(*args, **kwargs)

Wrapper class for the JAX optimizer: adagrad()
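A NumPy sketch of one step of the Adagrad variant we understand JAX's adagrad(step_size, momentum=0.9) to implement: squared gradients are summed, the gradient is rescaled by the inverse square root of that running sum, and the result is smoothed with momentum. The momentum term is an addition relative to textbook Adagrad; treat the exact form as an assumption and check it against the JAX source for your version. adagrad_step is a hypothetical helper.

```python
import numpy as np


def adagrad_step(x, g, g_sq, m, step_size, momentum=0.9):
    """One Adagrad-with-momentum step (sketch); returns updated (x, g_sq, m)."""
    g_sq = g_sq + g ** 2  # running sum of squared gradients
    # Rescale by 1 / sqrt(g_sq), leaving coordinates that never saw a gradient at zero.
    safe_sq = np.where(g_sq > 0, g_sq, 1.0)
    scaled = np.where(g_sq > 0, g / np.sqrt(safe_sq), 0.0)
    m = (1 - momentum) * scaled + momentum * m
    x = x - step_size * m
    return x, g_sq, m


x, g_sq, m = np.array([1.0, 1.0]), np.zeros(2), np.zeros(2)
x, g_sq, m = adagrad_step(x, np.array([2.0, 0.0]), g_sq, m, step_size=0.1)
# x is now approximately [0.99, 1.0]
```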

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the value of the objective function or any of its gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is computed from the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead re-evaluate the function multiple times to find optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with the parameters to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradient information for parameters.
  • state – current optimizer state.
Returns:

new optimizer state after the update.

ClippedAdam

class ClippedAdam(*args, clip_norm=10.0, **kwargs)

Adam optimizer with gradient clipping.

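Parameters: clip_norm (float) – each gradient value is clipped elementwise to the interval [-clip_norm, clip_norm]. Despite the name, this bounds individual values rather than rescaling the gradient by its norm.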

Reference:

Adam: A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba, https://arxiv.org/abs/1412.6980
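The clipping acts on each gradient entry independently. A sketch of that step on a pytree of gradients, using jax.tree_util.tree_map and jnp.clip; clip_grads is a hypothetical helper, and ClippedAdam would then feed the clipped gradients to the usual Adam update (it would be constructed as, for example, ClippedAdam(step_size=1e-3, clip_norm=2.0)).

```python
import jax
import jax.numpy as jnp


def clip_grads(grads, clip_norm=10.0):
    # Clip every gradient entry independently to [-clip_norm, clip_norm].
    return jax.tree_util.tree_map(lambda g: jnp.clip(g, -clip_norm, clip_norm), grads)


grads = {"w": jnp.array([-25.0, 0.5, 12.0]), "b": jnp.array(-3.0)}
clipped = clip_grads(grads, clip_norm=2.0)
# clipped["w"] -> [-2.0, 0.5, 2.0], clipped["b"] -> -2.0
```

Because each entry is clipped separately, the clipped gradient can point in a different direction than the original (here [-25, 0.5, 12] becomes [-2, 0.5, 2]). If rescaling by the global norm is wanted instead, an Optax chain such as optax.clip_by_global_norm followed by optax.adam can be wrapped with optax_to_numpyro.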

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the value of the objective function or any of its gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is computed from the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead re-evaluate the function multiple times to find optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with the parameters to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g, state)

Gradient update for the optimizer.

Parameters:
  • g – gradient information for parameters.
  • state – current optimizer state.
Returns:

new optimizer state after the update.

Minimize

class Minimize(method='BFGS', **kwargs)

Wrapper class for the JAX minimizer: minimize().

Example:

>>> from numpy.testing import assert_allclose
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import SVI, Trace_ELBO
>>> from numpyro.infer.autoguide import AutoLaplaceApproximation

>>> def model(x, y):
...     a = numpyro.sample("a", dist.Normal(0, 1))
...     b = numpyro.sample("b", dist.Normal(0, 1))
...     with numpyro.plate("N", y.shape[0]):
...         numpyro.sample("obs", dist.Normal(a + b * x, 0.1), obs=y)

>>> x = jnp.linspace(0, 10, 100)
>>> y = 3 * x + 2
>>> optimizer = numpyro.optim.Minimize()
>>> guide = AutoLaplaceApproximation(model)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> init_state = svi.init(random.PRNGKey(0), x, y)
>>> optimal_state, loss = svi.update(init_state, x, y)
>>> params = svi.get_params(optimal_state)  # get guide's parameters
>>> quantiles = guide.quantiles(params, 0.5)  # 0.5 quantiles, i.e. posterior medians
>>> assert_allclose(quantiles["a"], 2., atol=1e-3)
>>> assert_allclose(quantiles["b"], 3., atol=1e-3)
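To our understanding, Minimize hands the whole optimization to jax.scipy.optimize.minimize, which works on a flat 1-D array. A sketch of that idea on a toy objective, flattening a parameter pytree with jax.flatten_util.ravel_pytree; the objective and names are illustrative, and this is not NumPyro's implementation.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.optimize import minimize


def objective(params):
    # Toy objective over a dict of parameters, minimized at a = 2 and b = [3, 3].
    return (params["a"] - 2.0) ** 2 + jnp.sum((params["b"] - 3.0) ** 2)


init_params = {"a": jnp.array(0.0), "b": jnp.zeros(2)}

# BFGS in jax.scipy.optimize expects a flat 1-D array, so flatten the pytree and
# keep the function that maps a flat vector back to the original structure.
flat_init, unravel = ravel_pytree(init_params)
result = minimize(lambda x: objective(unravel(x)), flat_init, method="BFGS")
opt_params = unravel(result.x)
```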

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the value of the objective function or any of its gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is computed from the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead re-evaluate the function multiple times to find optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with the parameters to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradient information for parameters.
  • state – current optimizer state.
Returns:

new optimizer state after the update.

Momentum

class Momentum(*args, **kwargs)

Wrapper class for the JAX optimizer: momentum()
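A NumPy sketch of the classical (heavy-ball) momentum update as we understand JAX's momentum(step_size, mass) to apply it: the velocity accumulates gradients with decay factor mass, and the parameters move against the velocity. In JAX's signature mass has no default, so the wrapper would be constructed as, for example, Momentum(step_size=0.1, mass=0.9). momentum_step is a hypothetical helper.

```python
import numpy as np


def momentum_step(x, g, velocity, step_size, mass):
    """One heavy-ball momentum step (sketch); returns updated (x, velocity)."""
    velocity = mass * velocity + g  # decayed running sum of gradients
    x = x - step_size * velocity
    return x, velocity


x, velocity = 0.0, 0.0
for _ in range(2):
    # A constant gradient of 1.0: the step grows as velocity builds up.
    x, velocity = momentum_step(x, 1.0, velocity, step_size=0.1, mass=0.9)
# velocity: 1.0 then 1.9; x: -0.1 then -0.29
```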

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the value of the objective function or any of its gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is computed from the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead re-evaluate the function multiple times to find optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with the parameters to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradient information for parameters.
  • state – current optimizer state.
Returns:

new optimizer state after the update.

RMSProp

class RMSProp(*args, **kwargs)

Wrapper class for the JAX optimizer: rmsprop()
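A NumPy sketch of one RMSProp step as we understand JAX's rmsprop(step_size, gamma=0.9, eps=1e-8) to compute it: an exponential moving average of squared gradients normalizes the step. rmsprop_step is a hypothetical helper; check the defaults against your JAX version.

```python
import numpy as np


def rmsprop_step(x, g, avg_sq_grad, step_size, gamma=0.9, eps=1e-8):
    """One RMSProp step (sketch); returns updated (x, avg_sq_grad)."""
    avg_sq_grad = gamma * avg_sq_grad + (1 - gamma) * g ** 2  # moving average of g**2
    x = x - step_size * g / np.sqrt(avg_sq_grad + eps)
    return x, avg_sq_grad


x, avg_sq_grad = rmsprop_step(1.0, 2.0, 0.0, step_size=0.01)
# avg_sq_grad == 0.4; x moves by 0.01 * 2 / sqrt(0.4), about 0.0316
```

Since the average starts at zero and is not bias-corrected in this form, the first step is larger than step_size by a factor of 1/sqrt(1 - gamma), about 3.16 for gamma=0.9.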

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the value of the objective function or any of its gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is computed from the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead re-evaluate the function multiple times to find optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with the parameters to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradient information for parameters.
  • state – current optimizer state.
Returns:

new optimizer state after the update.

RMSPropMomentum

class RMSPropMomentum(*args, **kwargs)

Wrapper class for the JAX optimizer: rmsprop_momentum()
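A NumPy sketch of one step of RMSProp with momentum as we understand JAX's rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9) to compute it: the normalized RMSProp step is accumulated into a momentum buffer, which is then subtracted from the parameters. rmsprop_momentum_step is a hypothetical helper.

```python
import numpy as np


def rmsprop_momentum_step(x, g, avg_sq_grad, mom, step_size,
                          gamma=0.9, eps=1e-8, momentum=0.9):
    """One RMSProp-with-momentum step (sketch); returns updated (x, avg_sq_grad, mom)."""
    avg_sq_grad = gamma * avg_sq_grad + (1 - gamma) * g ** 2
    # The momentum buffer stores already-scaled steps, including step_size.
    mom = momentum * mom + step_size * g / np.sqrt(avg_sq_grad + eps)
    x = x - mom
    return x, avg_sq_grad, mom


state = (1.0, 0.0, 0.0)  # (x, avg_sq_grad, mom)
for _ in range(2):
    # Constant gradient of 2.0 on both steps.
    state = rmsprop_momentum_step(state[0], 2.0, state[1], state[2], step_size=0.01)
x, avg_sq_grad, mom = state
```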

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the value of the objective function or any of its gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is computed from the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead re-evaluate the function multiple times to find optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with the parameters to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradient information for parameters.
  • state – current optimizer state.
Returns:

new optimizer state after the update.

SGD

class SGD(*args, **kwargs)

Wrapper class for the JAX optimizer: sgd()

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the value of the objective function or any of its gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is computed from the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead re-evaluate the function multiple times to find optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with the parameters to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradient information for parameters.
  • state – current optimizer state.
Returns:

new optimizer state after the update.

SM3

class SM3(*args, **kwargs)

Wrapper class for the JAX optimizer: sm3()
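SM3 saves memory by keeping one accumulator per row and one per column of a matrix-shaped parameter, instead of a full matrix of squared-gradient sums. A NumPy sketch of one step for the 2-D case, based on our reading of JAX's sm3(step_size, momentum=0.9); treat the details as an assumption, and note that sm3_step is a hypothetical helper.

```python
import numpy as np


def sm3_step(x, g, m, row_acc, col_acc, step_size, momentum=0.9):
    """One SM3 step for a 2-D parameter (sketch); returns (x, m, row_acc, col_acc)."""
    # Per-entry accumulator estimate: the tighter of the row and column bounds, plus g**2.
    accum = np.minimum(row_acc[:, None], col_acc[None, :]) + g ** 2
    safe = np.where(accum > 0, accum, 1.0)
    inv_sqrt = np.where(accum > 0, 1.0 / np.sqrt(safe), 0.0)
    m = (1 - momentum) * (g * inv_sqrt) + momentum * m
    x = x - step_size * m
    # Keep only O(rows + cols) statistics instead of a full (rows x cols) accumulator.
    row_acc = accum.max(axis=1)
    col_acc = accum.max(axis=0)
    return x, m, row_acc, col_acc


g = np.array([[1.0, 2.0, 0.0], [3.0, 0.0, 1.0]])
x, m = np.zeros((2, 3)), np.zeros((2, 3))
x, m, row_acc, col_acc = sm3_step(x, g, m, np.zeros(2), np.zeros(3), step_size=0.1)
# row_acc -> [4.0, 9.0], col_acc -> [9.0, 4.0, 1.0]
```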

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the value of the objective function or any of its gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is computed from the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead re-evaluate the function multiple times to find optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.
Returns:

a pair of the output of the objective function and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with the parameters to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradient information for parameters.
  • state – current optimizer state.
Returns:

new optimizer state after the update.

Optax support

optax_to_numpyro(transformation: optax._src.transform.GradientTransformation) → numpyro.optim._NumPyroOptim

This function produces a numpyro.optim._NumPyroOptim instance from an optax.GradientTransformation so that it can be used with numpyro.infer.svi.SVI. It is a lightweight wrapper that recreates the (init_fn, update_fn, get_params_fn) interface defined by jax.experimental.optimizers.

Parameters: transformation – an optax.GradientTransformation instance to wrap.
Returns: an instance of numpyro.optim._NumPyroOptim wrapping the supplied Optax optimizer.