  • Values, including both state and action-values;

  • Values for non-linear generalizations of the Bellman equations;

  • Return Distributions, aka distributional value functions;

  • General Value Functions, for cumulants other than the main reward;

  • Policies, via policy-gradients in both continuous and discrete action spaces.

Value Learning

categorical_double_q_learning(q_atoms_tm1, …)

Implements double Q-learning for categorical Q distributions.

categorical_l2_project(z_p, probs, z_q)

Projects a categorical distribution (z_p, p) onto a different support z_q.

categorical_q_learning(q_atoms_tm1, …[, …])

Implements Q-learning for categorical Q distributions.

categorical_td_learning(v_atoms_tm1, …[, …])

Implements TD-learning for categorical value distributions.

discounted_returns(r_t, discount_t, v_t[, …])

Calculates a discounted return from a trajectory.

double_q_learning(q_tm1, a_tm1, r_t, …[, …])

Calculates the double Q-learning temporal difference error.

expected_sarsa(q_tm1, a_tm1, r_t, …[, …])

Calculates the expected SARSA (SARSE) temporal difference error.

general_off_policy_returns_from_action_values(…)

Calculates targets for various off-policy correction algorithms.

general_off_policy_returns_from_q_and_v(q_t, …)

Calculates targets for various off-policy evaluation algorithms.

lambda_returns(r_t, discount_t, v_t[, …])

Estimates a multistep truncated lambda return from a trajectory.

leaky_vtrace(v_tm1, v_t, r_t, discount_t, …)

Calculates Leaky V-Trace errors from importance weights.

leaky_vtrace_td_error_and_advantage(v_tm1, …)

Calculates Leaky V-Trace errors and PG advantage from importance weights.

n_step_bootstrapped_returns(r_t, discount_t, …)

Computes strided n-step bootstrapped return targets over a sequence.

persistent_q_learning(q_tm1, a_tm1, r_t, …)

Calculates the persistent Q-learning temporal difference error.

q_lambda(q_tm1, a_tm1, r_t, discount_t, q_t, …)

Calculates Peng’s or Watkins’ Q(lambda) temporal difference error.

q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)

Calculates the Q-learning temporal difference error.

quantile_expected_sarsa(dist_q_tm1, …[, …])

Implements Expected SARSA for quantile-valued Q distributions.

quantile_q_learning(dist_q_tm1, tau_q_tm1, …)

Implements Q-learning for quantile-valued Q distributions.

qv_learning(q_tm1, a_tm1, r_t, discount_t, v_t)

Calculates the QV-learning temporal difference error.

qv_max(v_tm1, r_t, discount_t, q_t[, …])

Calculates the QVMAX temporal difference error.

retrace(q_tm1, q_t, a_tm1, a_t, r_t, …[, …])

Calculates Retrace errors.

retrace_continuous(q_tm1, q_t, v_t, r_t, …)

Calculates Retrace errors for continuous action spaces.

sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, a_t)

Calculates the SARSA temporal difference error.

sarsa_lambda(q_tm1, a_tm1, r_t, discount_t, …)

Calculates the SARSA(lambda) temporal difference error.

td_lambda(v_tm1, r_t, discount_t, v_t, lambda_)

Calculates the TD(lambda) temporal difference error.

td_learning(v_tm1, r_t, discount_t, v_t[, …])

Calculates the TD-learning temporal difference error.

transformed_general_off_policy_returns_from_action_values(…)

Calculates targets for various off-policy correction algorithms.

transformed_lambda_returns(r_t, discount_t, v_t)

Estimates a multistep truncated lambda return from a trajectory.

transformed_n_step_q_learning(q_tm1, a_tm1, …)

Calculates transformed n-step TD errors.

transformed_n_step_returns(r_t, discount_t, …)

Computes strided n-step bootstrapped return targets over a sequence.

transformed_q_lambda(q_tm1, a_tm1, r_t, …)

Calculates Peng’s or Watkins’ Q(lambda) temporal difference error.

transformed_retrace(q_tm1, q_t, a_tm1, a_t, …)

Calculates transformed Retrace errors.

vtrace(v_tm1, v_t, r_t, discount_t, rho_tm1)

Calculates V-Trace errors from importance weights.

vtrace_td_error_and_advantage(v_tm1, v_t, …)

Calculates V-Trace errors and PG advantage from importance weights.

Categorical Double Q Learning

rlax.categorical_double_q_learning(q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t, q_logits_t, q_t_selector, stop_target_gradients=True)[source]

Implements double Q-learning for categorical Q distributions.

See “A Distributional Perspective on Reinforcement Learning” by Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf), and “Double Q-learning” by van Hasselt (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

Parameters
  • q_atoms_tm1 (Array) – atoms of Q distribution at time t-1.

  • q_logits_tm1 (Array) – logits of Q distribution at time t-1.

  • a_tm1 (Numeric) – action index at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • q_atoms_t (Array) – atoms of Q distribution at time t.

  • q_logits_t (Array) – logits of Q distribution at time t.

  • q_t_selector (Array) – selector Q-values at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

Categorical double Q-learning loss (i.e. temporal difference error).

Categorical L2 Project

rlax.categorical_l2_project(z_p, probs, z_q)[source]

Projects a categorical distribution (z_p, p) onto a different support z_q.

The projection step minimizes an L2-metric over the cumulative distribution functions (CDFs) of the source and target distributions.

Let kq be len(z_q) and kp be len(z_p). This projection works for any support z_q, in particular kq need not be equal to kp.

See “A Distributional Perspective on RL” by Bellemare et al. (https://arxiv.org/abs/1707.06887).

Parameters
  • z_p (Array) – support of distribution p.

  • probs (Array) – probability values.

  • z_q (Array) – support to project distribution (z_p, probs) onto.

Return type

Array

Returns

Projection of (z_p, p) onto support z_q under Cramer distance.
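
As an illustrative sketch (atom positions and probabilities below are made up for the example), the projection maps a distribution onto a fixed target support, as used when constructing categorical TD targets:

    import jax.numpy as jnp
    import rlax

    # Source distribution: 5 atoms with uniform probabilities (illustrative values).
    z_p = jnp.linspace(-2.0, 2.0, 5)
    probs = jnp.full((5,), 0.2)

    # Target support with a different number of atoms (kq need not equal kp).
    z_q = jnp.linspace(-1.0, 1.0, 11)

    # Returns one probability per atom of z_q.
    projected = rlax.categorical_l2_project(z_p, probs, z_q)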

Categorical Q Learning

rlax.categorical_q_learning(q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t, q_logits_t, stop_target_gradients=True)[source]

Implements Q-learning for categorical Q distributions.

See “A Distributional Perspective on Reinforcement Learning” by Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

Parameters
  • q_atoms_tm1 (Array) – atoms of Q distribution at time t-1.

  • q_logits_tm1 (Array) – logits of Q distribution at time t-1.

  • a_tm1 (Numeric) – action index at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • q_atoms_t (Array) – atoms of Q distribution at time t.

  • q_logits_t (Array) – logits of Q distribution at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

Categorical Q-learning loss (i.e. temporal difference error).

Categorical TD Learning

rlax.categorical_td_learning(v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t, stop_target_gradients=True)[source]

Implements TD-learning for categorical value distributions.

See “A Distributional Perspective on Reinforcement Learning” by Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

Parameters
  • v_atoms_tm1 (Array) – atoms of V distribution at time t-1.

  • v_logits_tm1 (Array) – logits of V distribution at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • v_atoms_t (Array) – atoms of V distribution at time t.

  • v_logits_t (Array) – logits of V distribution at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

Categorical Q learning loss (i.e. temporal difference error).

Discounted Returns

rlax.discounted_returns(r_t, discount_t, v_t, stop_target_gradients=False)[source]

Calculates a discounted return from a trajectory.

The returns are computed recursively, from G_{T-1} to G_0, according to:

Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁.

See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/sutton/book/ebook/node61.html).

Parameters
  • r_t (Array) – reward sequence at time t.

  • discount_t (Array) – discount sequence at time t.

  • v_t (Array) – value sequence or scalar at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

Discounted returns.
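
A minimal usage sketch for a short trajectory (the rewards, discounts and bootstrap values below are made up for illustration):

    import jax.numpy as jnp
    import rlax

    # Illustrative 3-step trajectory; a discount of 0. marks episode termination.
    r_t = jnp.array([1.0, 0.0, 2.0])
    discount_t = jnp.array([0.9, 0.9, 0.0])
    v_t = jnp.array([0.5, 0.5, 0.5])

    # Computed backwards along the sequence: Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁.
    returns = rlax.discounted_returns(r_t, discount_t, v_t)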

Double Q Learning

rlax.double_q_learning(q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector, stop_target_gradients=True)[source]

Calculates the double Q-learning temporal difference error.

See “Double Q-learning” by van Hasselt. (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

Parameters
  • q_tm1 (Array) – Q-values at time t-1.

  • a_tm1 (Numeric) – action index at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • q_t_value (Array) – Q-values at time t.

  • q_t_selector (Array) – selector Q-values at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

Double Q-learning temporal difference error.
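
A minimal sketch for a single transition (values are illustrative; typically q_t_value comes from a target network and q_t_selector from the online network):

    import jax.numpy as jnp
    import rlax

    q_tm1 = jnp.array([1.0, 2.0, 0.5])         # online Q-values at time t-1
    a_tm1 = 1                                   # action taken at time t-1
    r_t, discount_t = 0.5, 0.99
    q_t_value = jnp.array([1.5, 0.3, 0.2])      # Q-values used to evaluate the bootstrap
    q_t_selector = jnp.array([0.2, 1.0, 0.1])   # Q-values used to select the argmax action

    td_error = rlax.double_q_learning(
        q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector)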

Expected SARSA

rlax.expected_sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t, stop_target_gradients=True)[source]

Calculates the expected SARSA (SARSE) temporal difference error.

See “A Theoretical and Empirical Analysis of Expected Sarsa” by Seijen, van Hasselt, Whiteson et al. (http://www.cs.ox.ac.uk/people/shimon.whiteson/pubs/vanseijenadprl09.pdf).

Parameters
  • q_tm1 (Array) – Q-values at time t-1.

  • a_tm1 (Numeric) – action index at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • q_t (Array) – Q-values at time t.

  • probs_a_t (Array) – action probabilities at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

Expected SARSA temporal difference error.
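
A minimal sketch for a single transition (illustrative values); the bootstrap target is the expectation of q_t under probs_a_t:

    import jax.numpy as jnp
    import rlax

    q_tm1 = jnp.array([1.0, 0.0, -1.0])
    a_tm1 = 0
    r_t, discount_t = 1.0, 0.9
    q_t = jnp.array([0.5, 0.2, 0.1])
    probs_a_t = jnp.array([0.2, 0.5, 0.3])   # target policy probabilities at time t

    # target = r_t + discount_t * sum_a probs_a_t[a] * q_t[a]
    td_error = rlax.expected_sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t)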

General Off Policy Returns From Action Values

rlax.general_off_policy_returns_from_action_values(q_t, a_t, r_t, discount_t, c_t, pi_t, stop_target_gradients=False)[source]

Calculates targets for various off-policy correction algorithms.

Given a window of experience of length K, generated by a behaviour policy μ, for each time-step t we can estimate the return G_t from that step onwards, under some target policy π, using the rewards in the trajectory, the actions selected by μ and the action-values under π, according to equation:

Gₜ = rₜ₊₁ + γₜ₊₁ * (E[q(aₜ₊₁)] - cₜ * q(aₜ₊₁) + cₜ * Gₜ₊₁),

where, depending on the choice of c_t, the algorithm implements:

  • Importance Sampling: c_t = π(x_t, a_t) / μ(x_t, a_t);

  • Harutyunyan et al.’s Q(lambda): c_t = λ;

  • Precup et al.’s Tree-Backup: c_t = π(x_t, a_t);

  • Munos et al.’s Retrace: c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

See “Safe and Efficient Off-Policy Reinforcement Learning” by Munos et al. (https://arxiv.org/abs/1606.02647).

Parameters
  • q_t (Array) – Q-values at times [1, …, K - 1].

  • a_t (Array) – action index at times [1, …, K - 1].

  • r_t (Array) – reward at times [1, …, K - 1].

  • discount_t (Array) – discount at times [1, …, K - 1].

  • c_t (Array) – importance weights at times [1, …, K - 1].

  • pi_t (Array) – target policy probs at times [1, …, K - 1].

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

Off-policy estimates of the generalized returns from states visited at times [0, …, K - 1].

General Off Policy Returns From Q and V

rlax.general_off_policy_returns_from_q_and_v(q_t, v_t, r_t, discount_t, c_t, stop_target_gradients=False)[source]

Calculates targets for various off-policy evaluation algorithms.

Given a window of experience of length K+1, generated by a behaviour policy μ, for each time-step t we can estimate the return G_t from that step onwards, under some target policy π, using the rewards in the trajectory, the values under π of states and actions selected by μ, according to equation:

Gₜ = rₜ₊₁ + γₜ₊₁ * (vₜ₊₁ - cₜ₊₁ * q(aₜ₊₁) + cₜ₊₁* Gₜ₊₁),

where, depending on the choice of c_t, the algorithm implements:

  • Importance Sampling: c_t = π(x_t, a_t) / μ(x_t, a_t);

  • Harutyunyan et al.’s Q(lambda): c_t = λ;

  • Precup et al.’s Tree-Backup: c_t = π(x_t, a_t);

  • Munos et al.’s Retrace: c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

See “Safe and Efficient Off-Policy Reinforcement Learning” by Munos et al. (https://arxiv.org/abs/1606.02647).

Parameters
  • q_t (Array) – Q-values under π of actions executed by μ at times [1, …, K - 1].

  • v_t (Array) – Values under π at times [1, …, K].

  • r_t (Array) – rewards at times [1, …, K].

  • discount_t (Array) – discounts at times [1, …, K].

  • c_t (Array) – weights at times [1, …, K - 1].

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

Off-policy estimates of the generalized returns from states visited at times [0, …, K - 1].

Lambda Returns

rlax.lambda_returns(r_t, discount_t, v_t, lambda_=1.0, stop_target_gradients=False)[source]

Estimates a multistep truncated lambda return from a trajectory.

Given a trajectory of length T+1, generated under some policy π, for each time-step t we can estimate a target return G_t by combining rewards, discounts, and state values, according to a mixing parameter lambda.

The parameter lambda_ mixes the different multi-step bootstrapped returns, corresponding to accumulating k rewards and then bootstrapping using v_t.

rₜ₊₁ + γₜ₊₁ vₜ₊₁,
rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ vₜ₊₂,
rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ rₜ₊₃ + γₜ₊₁ γₜ₊₂ γₜ₊₃ vₜ₊₃,
…

The returns are computed recursively, from G_{T-1} to G_0, according to:

Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].

In the on-policy case, we estimate a return target G_t for the same policy π that was used to generate the trajectory. In this setting the parameter lambda_ is typically a fixed scalar factor. Depending on how values v_t are computed, this function can be used to construct targets for different multistep reinforcement learning updates:

  • TD(λ): v_t contains the state value estimates for each state under π.

  • Q(λ): v_t = max(q_t, axis=-1), where q_t estimates the action values.

  • Sarsa(λ): v_t = q_t[…, a_t], where q_t estimates the action values.

In the off-policy case, the mixing factor is a function of state, and different definitions of lambda implement different off-policy corrections:

  • Per-decision importance sampling: λₜ = λ ρₜ = λ [π(aₜ|sₜ) / μ(aₜ|sₜ)];

  • V-trace, as instantiated in IMPALA: λₜ = min(1, ρₜ).

Note that the second option is equivalent to applying per-decision importance sampling, but using an adaptive λ(ρₜ) = min(1/ρₜ, 1), such that the effective bootstrap parameter at time t becomes λₜ = λ(ρₜ) * ρₜ = min(1, ρₜ). This is the interpretation used in the ABQ(ζ) algorithm (Mahmood 2017).

Of course this can be augmented to include an additional factor λ. For instance we could use V-trace with a fixed additional parameter λ = 0.9, by setting λₜ = 0.9 * min(1, ρₜ) or, alternatively (but not equivalently), λₜ = min(0.9, ρₜ).

Estimated returns are then often used to define a TD error, e.g. ρₜ(Gₜ - vₜ).

See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/sutton/book/ebook/node74.html).

Parameters
  • r_t (Array) – sequence of rewards rₜ for timesteps t in [1, T].

  • discount_t (Array) – sequence of discounts γₜ for timesteps t in [1, T].

  • v_t (Array) – sequence of state values estimates under π for timesteps t in [1, T].

  • lambda – mixing parameter; a scalar or a vector for timesteps t in [1, T].

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

Multistep lambda returns.
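
A minimal sketch on an illustrative trajectory (all numeric values are made up); lambda_=0 recovers one-step TD targets, lambda_=1 the fully bootstrapped discounted return:

    import jax.numpy as jnp
    import rlax

    r_t = jnp.array([0.0, 0.0, 1.0, 0.0])
    discount_t = jnp.array([0.99, 0.99, 0.99, 0.0])
    v_t = jnp.array([0.6, 0.7, 0.3, 0.0])

    # Targets G_t for times [0, ..., T-1], computed backwards along the sequence.
    targets = rlax.lambda_returns(r_t, discount_t, v_t, lambda_=0.9)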

Leaky VTrace

rlax.leaky_vtrace(v_tm1, v_t, r_t, discount_t, rho_tm1, alpha_=1.0, lambda_=1.0, clip_rho_threshold=1.0, stop_target_gradients=True)[source]

Calculates Leaky V-Trace errors from importance weights.

Leaky-Vtrace is a combination of Importance sampling and V-trace, where the degree of mixing is controlled by a scalar alpha (that may be meta-learnt).

See “Self-Tuning Deep Reinforcement Learning” by Zahavy et al. (https://arxiv.org/abs/2002.12928)

Parameters
  • v_tm1 (Array) – values at time t-1.

  • v_t (Array) – values at time t.

  • r_t (Array) – reward at time t.

  • discount_t (Array) – discount at time t.

  • rho_tm1 (Array) – importance weights at time t-1.

  • alpha – mixing parameter for Importance Sampling and V-trace.

  • lambda – scalar mixing parameter lambda.

  • clip_rho_threshold (float) – clip threshold for importance weights.

  • stop_target_gradients (bool) – whether or not to apply stop gradient to targets.

Returns

Leaky V-Trace error.

N Step Bootstrapped Returns

rlax.n_step_bootstrapped_returns(r_t, discount_t, v_t, n, lambda_t=1.0, stop_target_gradients=False)[source]

Computes strided n-step bootstrapped return targets over a sequence.

The returns are computed according to the below equation iterated n times:

Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].

When lambda_t == 1. (default), this reduces to

Gₜ = rₜ₊₁ + γₜ₊₁ * (rₜ₊₂ + γₜ₊₂ * (… * (rₜ₊ₙ + γₜ₊ₙ * vₜ₊ₙ ))).

Parameters
  • r_t (Array) – rewards at times [1, …, T].

  • discount_t (Array) – discounts at times [1, …, T].

  • v_t (Array) – state or state-action values to bootstrap from at time [1, …., T].

  • n (int) – number of steps over which to accumulate reward before bootstrapping.

  • lambda_t (Numeric) – lambdas at times [1, …, T]. Shape is [], or [T-1].

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

estimated bootstrapped returns at times [0, …., T-1]
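
A minimal sketch over an illustrative sequence of length T=5, bootstrapping after n=3 steps (all values are made up):

    import jax.numpy as jnp
    import rlax

    r_t = jnp.array([1.0, 0.0, 0.0, 1.0, 0.0])
    discount_t = jnp.full((5,), 0.9)
    v_t = jnp.array([0.2, 0.4, 0.6, 0.8, 1.0])

    # targets[0] = r₁ + γ r₂ + γ² r₃ + γ³ v₃ in the docstring's index convention.
    targets = rlax.n_step_bootstrapped_returns(r_t, discount_t, v_t, n=3)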

Leaky VTrace TD Error and Advantage

rlax.leaky_vtrace_td_error_and_advantage(v_tm1, v_t, r_t, discount_t, rho_tm1, alpha=1.0, lambda_=1.0, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, stop_target_gradients=True)[source]

Calculates Leaky V-Trace errors and PG advantage from importance weights.

This functions computes the Leaky V-Trace TD-errors and policy gradient Advantage terms as used by the IMPALA distributed actor-critic agent.

Leaky-Vtrace is a combination of Importance sampling and V-trace, where the degree of mixing is controlled by a scalar alpha (that may be meta-learnt).

See “Self-Tuning Deep Reinforcement Learning” by Zahavy et al. (https://arxiv.org/abs/2002.12928) and “IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor Learner Architectures” by Espeholt et al. (https://arxiv.org/abs/1802.01561)

Parameters
  • v_tm1 (chex.Array) – values at time t-1.

  • v_t (chex.Array) – values at time t.

  • r_t (chex.Array) – reward at time t.

  • discount_t (chex.Array) – discount at time t.

  • rho_tm1 (chex.Array) – importance weights at time t-1.

  • alpha (float) – coefficient for mixing the clipped importance sampling weights with unclipped ones.

  • lambda – scalar mixing parameter lambda.

  • clip_rho_threshold (float) – clip threshold for importance ratios.

  • clip_pg_rho_threshold (float) – clip threshold for policy gradient importance ratios.

  • stop_target_gradients (bool) – whether or not to apply stop gradient to targets.

Return type

VTraceOutput

Returns

a tuple of V-Trace error, policy gradient advantage, and estimated Q-values.

Persistent Q Learning

rlax.persistent_q_learning(q_tm1, a_tm1, r_t, discount_t, q_t, action_gap_scale, stop_target_gradients=True)[source]

Calculates the persistent Q-learning temporal difference error.

See “Increasing the Action Gap: New Operators for Reinforcement Learning” by Bellemare, Ostrovski, Guez et al. (https://arxiv.org/abs/1512.04860).

Parameters
  • q_tm1 (Array) – Q-values at time t-1.

  • a_tm1 (Numeric) – action index at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • q_t (Array) – Q-values at time t.

  • action_gap_scale (float) – coefficient in [0, 1] for scaling the action gap term.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

Persistent Q-learning temporal difference error.

Q-Lambda

rlax.q_lambda(q_tm1, a_tm1, r_t, discount_t, q_t, lambda_, stop_target_gradients=True)[source]

Calculates Peng’s or Watkins’ Q(lambda) temporal difference error.

See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/book/ebook/node78.html).

Parameters
  • q_tm1 (Array) – sequence of Q-values at time t-1.

  • a_tm1 (Array) – sequence of action indices at time t-1.

  • r_t (Array) – sequence of rewards at time t.

  • discount_t (Array) – sequence of discounts at time t.

  • q_t (Array) – sequence of Q-values at time t.

  • lambda – mixing parameter lambda, either a scalar (e.g. Peng’s Q(lambda)) or a sequence (e.g. Watkins’ Q(lambda)).

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

Q(lambda) temporal difference error.

Q Learning

rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t, stop_target_gradients=True)[source]

Calculates the Q-learning temporal difference error.

See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/book/ebook/node65.html).

Parameters
  • q_tm1 (Array) – Q-values at time t-1.

  • a_tm1 (Numeric) – action index at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • q_t (Array) – Q-values at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

Q-learning temporal difference error.
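
A minimal sketch for a single transition (illustrative values); for batched transitions the usual pattern is to wrap the call in jax.vmap:

    import jax.numpy as jnp
    import rlax

    q_tm1 = jnp.array([1.0, 2.0, 3.0])
    a_tm1 = 2
    r_t, discount_t = 1.0, 0.9
    q_t = jnp.array([0.0, 2.0, 1.0])

    # td_error = r_t + discount_t * max(q_t) - q_tm1[a_tm1] = 1.0 + 1.8 - 3.0 = -0.2
    td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)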

Quantile Expected Sarsa

rlax.quantile_expected_sarsa(dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t, probs_a_t, huber_param=0.0, stop_target_gradients=True)[source]

Implements Expected SARSA for quantile-valued Q distributions.

Parameters
  • dist_q_tm1 (Array) – Q distribution at time t-1.

  • tau_q_tm1 (Array) – Q distribution probability thresholds.

  • a_tm1 (Numeric) – action index at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • dist_q_t (Array) – target Q distribution at time t.

  • probs_a_t (Array) – action probabilities at time t.

  • huber_param (float) – Huber loss parameter, defaults to 0 (no Huber loss).

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

Quantile regression Expected SARSA learning loss.

Quantile Q Learning

rlax.quantile_q_learning(dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector, dist_q_t, huber_param=0.0, stop_target_gradients=True)[source]

Implements Q-learning for quantile-valued Q distributions.

See “Distributional Reinforcement Learning with Quantile Regression” by Dabney et al. (https://arxiv.org/abs/1710.10044).

Parameters
  • dist_q_tm1 (Array) – Q distribution at time t-1.

  • tau_q_tm1 (Array) – Q distribution probability thresholds.

  • a_tm1 (Numeric) – action index at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • dist_q_t_selector (Array) – Q distribution at time t for selecting greedy action in target policy. This is separate from dist_q_t as in Double Q-Learning, but can be computed with the target network and a separate set of samples.

  • dist_q_t (Array) – target Q distribution at time t.

  • huber_param (float) – Huber loss parameter, defaults to 0 (no Huber loss).

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

Quantile regression Q learning loss.

QV Learning

rlax.qv_learning(q_tm1, a_tm1, r_t, discount_t, v_t, stop_target_gradients=True)[source]

Calculates the QV-learning temporal difference error.

See “Two Novel On-policy Reinforcement Learning Algorithms based on TD(lambda)-methods” by Wiering and van Hasselt (https://ieeexplore.ieee.org/abstract/document/4220845).

Parameters
  • q_tm1 (Array) – Q-values at time t-1.

  • a_tm1 (Numeric) – action index at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • v_t (Numeric) – state values at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

QV-learning temporal difference error.

QV Max

rlax.qv_max(v_tm1, r_t, discount_t, q_t, stop_target_gradients=True)[source]

Calculates the QVMAX temporal difference error.

See “The QV Family Compared to Other Reinforcement Learning Algorithms” by Wiering and van Hasselt (2009). (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.713.1931)

Parameters
  • v_tm1 (Numeric) – state values at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • q_t (Array) – Q-values at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

QVMAX temporal difference error.

Retrace

rlax.retrace(q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t, lambda_, eps=1e-08, stop_target_gradients=True)[source]

Calculates Retrace errors.

See “Safe and Efficient Off-Policy Reinforcement Learning” by Munos et al. (https://arxiv.org/abs/1606.02647).

Parameters
  • q_tm1 (Array) – Q-values at time t-1.

  • q_t (Array) – Q-values at time t.

  • a_tm1 (Array) – action index at time t-1.

  • a_t (Array) – action index at time t.

  • r_t (Array) – reward at time t.

  • discount_t (Array) – discount at time t.

  • pi_t (Array) – target policy probs at time t.

  • mu_t (Array) – behavior policy probs at time t.

  • lambda – scalar mixing parameter lambda.

  • eps (float) – small value to add to mu_t for numerical stability.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

Retrace error.

Retrace Continuous

rlax.retrace_continuous(q_tm1, q_t, v_t, r_t, discount_t, log_rhos, lambda_, stop_target_gradients=True)[source]

Calculates Retrace errors for continuous action spaces.

See “Safe and Efficient Off-Policy Reinforcement Learning” by Munos et al. (https://arxiv.org/abs/1606.02647).

Parameters
  • q_tm1 (Array) – Q-values at times [0, …, K - 1].

  • q_t (Array) – Q-values evaluated at actions collected using behavior policy at times [1, …, K - 1].

  • v_t (Array) – Value estimates of the target policy at times [1, …, K].

  • r_t (Array) – reward at times [1, …, K].

  • discount_t (Array) – discount at times [1, …, K].

  • log_rhos (Array) – Log importance weight pi_target/pi_behavior evaluated at actions collected using behavior policy [1, …, K - 1].

  • lambda – scalar or a vector of mixing parameter lambda.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

Retrace error.

SARSA

rlax.sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, a_t, stop_target_gradients=True)[source]

Calculates the SARSA temporal difference error.

See “Reinforcement Learning: An Introduction” by Sutton and Barto (http://incompleteideas.net/book/ebook/node64.html).

Parameters
  • q_tm1 (Array) – Q-values at time t-1.

  • a_tm1 (Numeric) – action index at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • q_t (Array) – Q-values at time t.

  • a_t (Numeric) – action index at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

SARSA temporal difference error.

SARSA Lambda

rlax.sarsa_lambda(q_tm1, a_tm1, r_t, discount_t, q_t, a_t, lambda_, stop_target_gradients=True)[source]

Calculates the SARSA(lambda) temporal difference error.

See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/book/ebook/node77.html).

Parameters
  • q_tm1 (Array) – sequence of Q-values at time t-1.

  • a_tm1 (Array) – sequence of action indices at time t-1.

  • r_t (Array) – sequence of rewards at time t.

  • discount_t (Array) – sequence of discounts at time t.

  • q_t (Array) – sequence of Q-values at time t.

  • a_t (Array) – sequence of action indices at time t.

  • lambda – mixing parameter lambda, either a scalar or a sequence.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

SARSA(lambda) temporal difference error.

TD Lambda

rlax.td_lambda(v_tm1, r_t, discount_t, v_t, lambda_, stop_target_gradients=True)[source]

Calculates the TD(lambda) temporal difference error.

See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/book/ebook/node74.html).

Parameters
  • v_tm1 (Array) – sequence of state values at time t-1.

  • r_t (Array) – sequence of rewards at time t.

  • discount_t (Array) – sequence of discounts at time t.

  • v_t (Array) – sequence of state values at time t.

  • lambda – mixing parameter lambda, either a scalar or a sequence.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

TD(lambda) temporal difference error.

TD Learning

rlax.td_learning(v_tm1, r_t, discount_t, v_t, stop_target_gradients=True)[source]

Calculates the TD-learning temporal difference error.

See “Learning to Predict by the Methods of Temporal Differences” by Sutton. (https://link.springer.com/article/10.1023/A:1022633531479).

Parameters
  • v_tm1 (Numeric) – state values at time t-1.

  • r_t (Numeric) – reward at time t.

  • discount_t (Numeric) – discount at time t.

  • v_t (Numeric) – state values at time t.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Numeric

Returns

TD-learning temporal difference error.
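
A minimal sketch for a single transition (illustrative values):

    import jax.numpy as jnp
    import rlax

    v_tm1 = jnp.array(0.5)
    r_t = jnp.array(1.0)
    discount_t = jnp.array(0.9)
    v_t = jnp.array(0.8)

    # td_error = r_t + discount_t * v_t - v_tm1 = 1.0 + 0.72 - 0.5 = 1.22
    td_error = rlax.td_learning(v_tm1, r_t, discount_t, v_t)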

Transformed General Off Policy Returns from Action Values

rlax.transformed_general_off_policy_returns_from_action_values(q_t, a_t, r_t, discount_t, c_t, pi_t, stop_target_gradients=False)[source]

Calculates targets for various off-policy correction algorithms.

Given a window of experience of length K, generated by a behaviour policy μ, for each time-step t we can estimate the return G_t from that step onwards, under some target policy π, using the rewards in the trajectory, the actions selected by μ and the action-values under π, according to equation:

Gₜ = rₜ₊₁ + γₜ₊₁ * (E[q(aₜ₊₁)] - cₜ * q(aₜ₊₁) + cₜ * Gₜ₊₁),

where, depending on the choice of c_t, the algorithm implements:

  • Importance Sampling: c_t = π(x_t, a_t) / μ(x_t, a_t);

  • Harutyunyan et al.’s Q(lambda): c_t = λ;

  • Precup et al.’s Tree-Backup: c_t = π(x_t, a_t);

  • Munos et al.’s Retrace: c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

See “Safe and Efficient Off-Policy Reinforcement Learning” by Munos et al. (https://arxiv.org/abs/1606.02647).

Parameters
  • q_t (Array) – Q-values at times [1, …, K - 1].

  • a_t (Array) – action index at times [1, …, K - 1].

  • r_t (Array) – reward at times [1, …, K - 1].

  • discount_t (Array) – discount at times [1, …, K - 1].

  • c_t (Array) – importance weights at times [1, …, K - 1].

  • pi_t (Array) – target policy probs at times [1, …, K - 1].

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

Off-policy estimates of the generalized returns from states visited at times [0, …, K - 1].

Transformed Lambda Returns

rlax.transformed_lambda_returns(r_t, discount_t, v_t, lambda_=1.0, stop_target_gradients=False)[source]

Estimates a multistep truncated lambda return from a trajectory.

Given a trajectory of length T+1, generated under some policy π, for each time-step t we can estimate a target return G_t by combining rewards, discounts, and state values, according to a mixing parameter lambda.

The parameter lambda_ mixes the different multi-step bootstrapped returns, corresponding to accumulating k rewards and then bootstrapping using v_t.

rₜ₊₁ + γₜ₊₁ vₜ₊₁,
rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ vₜ₊₂,
rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ rₜ₊₃ + γₜ₊₁ γₜ₊₂ γₜ₊₃ vₜ₊₃,
…

The returns are computed recursively, from G_{T-1} to G_0, according to:

Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].

In the on-policy case, we estimate a return target G_t for the same policy π that was used to generate the trajectory. In this setting the parameter lambda_ is typically a fixed scalar factor. Depending on how values v_t are computed, this function can be used to construct targets for different multistep reinforcement learning updates:

  • TD(λ): v_t contains the state value estimates for each state under π.

  • Q(λ): v_t = max(q_t, axis=-1), where q_t estimates the action values.

  • Sarsa(λ): v_t = q_t[…, a_t], where q_t estimates the action values.

In the off-policy case, the mixing factor is a function of state, and different definitions of lambda implement different off-policy corrections:

  • Per-decision importance sampling: λₜ = λ ρₜ = λ [π(aₜ|sₜ) / μ(aₜ|sₜ)];

  • V-trace, as instantiated in IMPALA: λₜ = min(1, ρₜ).

Note that the second option is equivalent to applying per-decision importance sampling, but using an adaptive λ(ρₜ) = min(1/ρₜ, 1), such that the effective bootstrap parameter at time t becomes λₜ = λ(ρₜ) * ρₜ = min(1, ρₜ). This is the interpretation used in the ABQ(ζ) algorithm (Mahmood 2017).

Of course this can be augmented to include an additional factor λ. For instance we could use V-trace with a fixed additional parameter λ = 0.9, by setting λₜ = 0.9 * min(1, ρₜ) or, alternatively (but not equivalently), λₜ = min(0.9, ρₜ).

Estimated returns are then often used to define a TD error, e.g. ρₜ(Gₜ - vₜ).

See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/sutton/book/ebook/node74.html).

Parameters
  • r_t (Array) – sequence of rewards rₜ for timesteps t in [1, T].

  • discount_t (Array) – sequence of discounts γₜ for timesteps t in [1, T].

  • v_t (Array) – sequence of state values estimates under π for timesteps t in [1, T].

  • lambda – mixing parameter; a scalar or a vector for timesteps t in [1, T].

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

Multistep lambda returns.

Transformed N Step Q Learning

rlax.transformed_n_step_q_learning(q_tm1, a_tm1, target_q_t, a_t, r_t, discount_t, n, stop_target_gradients=True, tx_pair=TxPair(apply=<function identity>, apply_inv=<function identity>))[source]

Calculates transformed n-step TD errors.

See “Recurrent Experience Replay in Distributed Reinforcement Learning” by Kapturowski et al. (https://openreview.net/pdf?id=r1lyTjAqYX).

Parameters
  • q_tm1 (Array) – Q-values at times [0, …, T - 1].

  • a_tm1 (Array) – action index at times [0, …, T - 1].

  • target_q_t (Array) – target Q-values at time [1, …, T].

  • a_t (Array) – action index at times [1, …, T] used to select target q-values to bootstrap from; max(target_q_t) for normal Q-learning, max(q_t) for double Q-learning.

  • r_t (Array) – reward at times [1, …, T].

  • discount_t (Array) – discount at times [1, …, T].

  • n (int) – number of steps over which to accumulate reward before bootstrapping.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

  • tx_pair (TxPair) – TxPair of value function transformation and its inverse.

Return type

Array

Returns

Transformed N-step TD error.

Transformed N Step Returns

rlax.transformed_n_step_returns(r_t, discount_t, v_t, n, lambda_t=1.0, stop_target_gradients=False)[source]

Computes strided n-step bootstrapped return targets over a sequence.

The returns are computed according to the below equation iterated n times:

Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].

When lambda_t == 1. (default), this reduces to

Gₜ = rₜ₊₁ + γₜ₊₁ * (rₜ₊₂ + γₜ₊₂ * (… * (rₜ₊ₙ + γₜ₊ₙ * vₜ₊ₙ ))).

Parameters
  • r_t (Array) – rewards at times [1, …, T].

  • discount_t (Array) – discounts at times [1, …, T].

  • v_t (Array) – state or state-action values to bootstrap from at time [1, …., T].

  • n (int) – number of steps over which to accumulate reward before bootstrapping.

  • lambda_t (Numeric) – lambdas at times [1, …, T]. Shape is [], or [T-1].

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

estimated bootstrapped returns at times [0, …., T-1]

Transformed Q Lambda

rlax.transformed_q_lambda(q_tm1, a_tm1, r_t, discount_t, q_t, lambda_, stop_target_gradients=True, tx_pair=TxPair(apply=<function identity>, apply_inv=<function identity>))[source]

Calculates Peng’s or Watkins’ Q(lambda) temporal difference error.

See “General non-linear Bellman equations” by van Hasselt et al. (https://arxiv.org/abs/1907.03687).

Parameters
  • q_tm1 (Array) – sequence of Q-values at time t-1.

  • a_tm1 (Array) – sequence of action indices at time t-1.

  • r_t (Array) – sequence of rewards at time t.

  • discount_t (Array) – sequence of discounts at time t.

  • q_t (Array) – sequence of Q-values at time t.

  • lambda – mixing parameter lambda, either a scalar (e.g. Peng’s Q(lambda)) or a sequence (e.g. Watkins’ Q(lambda)).

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

  • tx_pair (TxPair) – TxPair of value function transformation and its inverse.

Return type

Array

Returns

Q(lambda) temporal difference error.

Transformed Retrace


rlax.transformed_retrace(q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t, lambda_, eps=1e-08, stop_target_gradients=True, tx_pair=TxPair(apply=<function identity>, apply_inv=<function identity>))[source]

Calculates transformed Retrace errors.

See “Recurrent Experience Replay in Distributed Reinforcement Learning” by Kapturowski et al. (https://openreview.net/pdf?id=r1lyTjAqYX).

Parameters
  • q_tm1 (Array) – Q-values at time t-1.

  • q_t (Array) – Q-values at time t.

  • a_tm1 (Array) – action index at time t-1.

  • a_t (Array) – action index at time t.

  • r_t (Array) – reward at time t.

  • discount_t (Array) – discount at time t.

  • pi_t (Array) – target policy probs at time t.

  • mu_t (Array) – behavior policy probs at time t.

  • lambda – scalar mixing parameter lambda.

  • eps (float) – small value to add to mu_t for numerical stability.

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

  • tx_pair (TxPair) – TxPair of value function transformation and its inverse.

Return type

Array

Returns

Transformed Retrace error.

Truncated Generalized Advantage Estimation

rlax.truncated_generalized_advantage_estimation(r_t, discount_t, lambda_, values, stop_target_gradients=False)[source]

Computes truncated generalized advantage estimates for a sequence length k.

The advantages are computed in a backwards fashion according to the equation: Âₜ = δₜ + (γλ) * δₜ₊₁ + … + (γλ)ᵏ⁻ᵗ⁺¹ * δₖ₋₁, where δₜ = rₜ₊₁ + γₜ₊₁ * v(sₜ₊₁) - v(sₜ).

See Proximal Policy Optimization Algorithms, Schulman et al.: https://arxiv.org/abs/1707.06347

Note: This paper uses a different notation than the RLax standard convention that follows Sutton & Barto. We use rₜ₊₁ to denote the reward received after acting in state sₜ, while the PPO paper uses rₜ.

Parameters
  • r_t (Array) – Sequence of rewards at times [1, k]

  • discount_t (Array) – Sequence of discounts at times [1, k]

  • lambda – Mixing parameter; a scalar or sequence of lambda_t at times [1, k]

  • values (Array) – Sequence of values under π at times [0, k]

  • stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

Multistep truncated generalized advantage estimation at times [0, k-1].
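
A minimal sketch for a rollout of length k=3 (illustrative values); note that values has one more entry than r_t:

    import jax.numpy as jnp
    import rlax

    r_t = jnp.array([1.0, 0.0, 1.0])
    discount_t = jnp.array([0.99, 0.99, 0.0])
    values = jnp.array([0.5, 0.6, 0.7, 0.0])   # v(s₀) ... v(s₃)

    advantages = rlax.truncated_generalized_advantage_estimation(
        r_t, discount_t, lambda_=0.95, values=values)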

VTrace

rlax.vtrace(v_tm1, v_t, r_t, discount_t, rho_tm1, lambda_=1.0, clip_rho_threshold=1.0, stop_target_gradients=True)[source]

Calculates V-Trace errors from importance weights.

V-trace computes TD-errors from multistep trajectories by applying off-policy corrections based on clipped importance sampling ratios.

See “IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor Learner Architectures” by Espeholt et al. (https://arxiv.org/abs/1802.01561).

Parameters
  • v_tm1 (Array) – values at time t-1.

  • v_t (Array) – values at time t.

  • r_t (Array) – reward at time t.

  • discount_t (Array) – discount at time t.

  • rho_tm1 (Array) – importance sampling ratios at time t-1.

  • lambda – scalar mixing parameter lambda.

  • clip_rho_threshold (float) – clip threshold for importance weights.

  • stop_target_gradients (bool) – whether or not to apply stop gradient to targets.

Return type

Array

Returns

V-Trace error.
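
A minimal sketch over an illustrative off-policy sequence of length 3 (values and importance ratios are made up):

    import jax.numpy as jnp
    import rlax

    v_tm1 = jnp.array([0.5, 0.6, 0.7])       # value estimates at times [0, 1, 2]
    v_t = jnp.array([0.6, 0.7, 0.8])         # value estimates at times [1, 2, 3]
    r_t = jnp.array([1.0, 0.0, 1.0])
    discount_t = jnp.array([0.99, 0.99, 0.0])
    rho_tm1 = jnp.array([1.2, 0.8, 1.0])     # π/μ importance ratios for the taken actions

    errors = rlax.vtrace(v_tm1, v_t, r_t, discount_t, rho_tm1, lambda_=1.0)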

Policy Optimization

clipped_surrogate_pg_loss(prob_ratios_t, …)

Computes the clipped surrogate policy gradient loss.

dpg_loss(a_t, dqda_t[, dqda_clipping, …])

Calculates the deterministic policy gradient (DPG) loss.

entropy_loss(logits_t, w_t)

Calculates the entropy regularization loss.

LagrangePenalty(alpha, epsilon, per_dimension)

mpo_loss(sample_log_probs, sample_q_values, …)

Implements the MPO loss with a KL bound.

mpo_compute_weights_and_temperature_loss(…)

Computes the weights and temperature loss for MPO.

policy_gradient_loss(logits_t, a_t, adv_t, w_t)

Calculates the policy gradient loss.

qpg_loss(logits_t, q_t[, use_stop_gradient])

Computes the QPG (Q-based Policy Gradient) loss.

rm_loss(logits_t, q_t[, use_stop_gradient])

Computes the RMPG (Regret Matching Policy Gradient) loss.

rpg_loss(logits_t, q_t[, use_stop_gradient])

Computes the RPG (Regret Policy Gradient) loss.

Clipped Surrogate PG Loss

rlax.clipped_surrogate_pg_loss(prob_ratios_t, adv_t, epsilon, use_stop_gradient=True)[source]

Computes the clipped surrogate policy gradient loss.

L_clipₜ(θ) = - min(rₜ(θ)Âₜ, clip(rₜ(θ), 1-ε, 1+ε)Âₜ)

Where rₜ(θ) = π_θ(aₜ| sₜ) / π_θ_old(aₜ| sₜ) and Âₜ are the advantages.

See Proximal Policy Optimization Algorithms, Schulman et al.: https://arxiv.org/abs/1707.06347

Parameters
  • prob_ratios_t (Array) – Ratio of action probabilities for actions a_t: rₜ(θ) = π_θ(aₜ| sₜ) / π_θ_old(aₜ| sₜ)

  • adv_t (Array) – the observed or estimated advantages from executing actions a_t.

  • epsilon (Scalar) – Scalar value corresponding to how much to clip the objective.

  • use_stop_gradient – bool indicating whether or not to apply stop gradient to advantages.

Return type

Array

Returns

Loss whose gradient corresponds to a clipped surrogate policy gradient update.
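
A minimal sketch for a small batch of actions (ratios and advantages below are made up for illustration):

    import jax.numpy as jnp
    import rlax

    prob_ratios_t = jnp.array([1.1, 0.9, 1.3, 0.7])   # π_new(aₜ|sₜ) / π_old(aₜ|sₜ)
    adv_t = jnp.array([0.5, -0.2, 1.0, -0.5])

    loss = rlax.clipped_surrogate_pg_loss(prob_ratios_t, adv_t, epsilon=0.2)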

Compute Parametric KL Penalty and Dual Loss

rlax.compute_parametric_kl_penalty_and_dual_loss(kl_constraints, projection_operator, use_stop_gradient=True)[source]

Optimize hard KL constraints between the current and previous policies.

Return type

Tuple[Array, Array]

DPG Loss

rlax.dpg_loss(a_t, dqda_t, dqda_clipping=None, use_stop_gradient=True)[source]

Calculates the deterministic policy gradient (DPG) loss.

See “Deterministic Policy Gradient Algorithms” by Silver, Lever, Heess, Degris, Wierstra, Riedmiller (http://proceedings.mlr.press/v32/silver14.pdf).

Parameters
  • a_t (Array) – continuous-valued action at time t.

  • dqda_t (Array) – gradient of Q(s,a) wrt. a, evaluated at time t.

  • dqda_clipping (Optional[Scalar]) – clips the gradient to have norm <= dqda_clipping.

  • use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to targets.

Return type

Array

Returns

DPG loss.

Entropy Loss

rlax.entropy_loss(logits_t, w_t)[source]

Calculates the entropy regularization loss.

See “Function Optimization using Connectionist RL Algorithms” by Williams. (https://www.tandfonline.com/doi/abs/10.1080/09540099108946587)

Parameters
  • logits_t (Array) – a sequence of unnormalized action preferences.

  • w_t (Array) – a per timestep weighting for the loss.

Return type

Array

Returns

Entropy loss.

Lagrange Penalty

class rlax.LagrangePenalty(alpha, epsilon, per_dimension)[source]
property alpha

Alias for field number 0

property epsilon

Alias for field number 1

property per_dimension

Alias for field number 2

__getnewargs__()[source]

Return self as a plain tuple. Used by copy and pickle.

MPO Compute Weights and Temperature Loss

rlax.mpo_compute_weights_and_temperature_loss(sample_q_values, temperature_constraint, projection_operator, sample_axis=0)[source]

Computes the weights and temperature loss for MPO.

The E-Step computes a non-parametric sample-based approximation of the current policy by reweighting the state-action value function.

Here, we compute this nonparametric policy and optimize the temperature parameter used in the reweighting.

Parameters
  • sample_q_values (Array) – An array of shape E* + a sample axis inserted at sample_axis containing the q function values evaluated on the sampled actions.

  • temperature_constraint (LagrangePenalty) – Lagrange constraint for the E-step temperature optimization.

  • projection_operator (Callable[[Numeric], Numeric]) – Function to project temperature into the positive range.

  • sample_axis (int) – Axis in sample_q_values containing sampled actions.

Return type

Tuple[Array, Array, Scalar]

Returns

The temperature loss, normalized weights and number of action samples per state.

MPO Loss

rlax.mpo_loss(sample_log_probs, sample_q_values, temperature_constraint, kl_constraints, projection_operator=functools.partial(<CompiledFunction of <function clip>>, a_min=1e-10), policy_loss_weight=1.0, temperature_loss_weight=1.0, kl_loss_weight=1.0, alpha_loss_weight=1.0, sample_axis=0, use_stop_gradient=True)[source]

Implements the MPO loss with a KL bound.

This loss implements the MPO algorithm for policies with a bound for the KL between the current and target policy.

Note: This is a per-example loss which works on any shape inputs as long as they are consistent. We denote this shape E* for ease of reference. Args sample_log_probs and sample_q_values are shape E* + an extra sample axis that contains the sampled actions’ log probs and q values respectively. For example, if sample_axis = 0, the shapes expected will be [S, E*]. Or if E* = [T, B] and sample_axis = 1, the shapes expected will be [T, S, B].

Parameters
  • sample_log_probs (Array) – An array of shape E* + a sample axis inserted at sample_axis containing the log probabilities of the sampled actions under the current policy.

  • sample_q_values (Array) – An array of shape E* + a sample axis inserted at sample_axis containing the q function values evaluated on the sampled actions.

  • temperature_constraint (LagrangePenalty) – Lagrange constraint for the E-step temperature optimization.

  • kl_constraints (Sequence[Tuple[Array, LagrangePenalty]]) – KL and variables for applying Lagrangian penalties to bound them in the M-step, KLs are [E*, A?]. Here A is the action dimension in the case of per-dimension KL constraints.

  • projection_operator (Callable[[Numeric], Numeric]) – Function to project dual variables (temperature and kl constraint alphas) into the positive range.

  • policy_loss_weight (float) – Weight for the policy loss.

  • temperature_loss_weight (float) – Weight for the temperature loss.

  • kl_loss_weight (float) – Weight for the KL loss.

  • alpha_loss_weight (float) – Weight for the alpha loss.

  • sample_axis (int) – Axis in sample_log_probs and sample_q_values that contains the sampled actions’ log probs and q values respectively. For example, if sample_axis = 0, the shapes expected will be [S, E*]. Or if E* = [T, B] and sample_axis = 1, the shapes expected will be [T, S, B].

  • use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient.

Return type

Tuple[Array, MpoOutputs]

Returns

Per example loss with shape E*, and additional data including the components of this loss and the normalized weights in the AdditionalOutputs.

Policy Gradient Loss

rlax.policy_gradient_loss(logits_t, a_t, adv_t, w_t, use_stop_gradient=True)[source]

Calculates the policy gradient loss.

See “Simple Gradient-Following Algorithms for Connectionist RL” by Williams. (http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)

Parameters
  • logits_t (Array) – a sequence of unnormalized action preferences.

  • a_t (Array) – a sequence of actions sampled from the preferences logits_t.

  • adv_t (Array) – the observed or estimated advantages from executing actions a_t.

  • w_t (Array) – a per timestep weighting for the loss.

  • use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to advantages.

Return type

Array

Returns

Loss whose gradient corresponds to a policy gradient update.
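
A minimal sketch for a sequence of length 3 with 4 discrete actions (logits, actions and advantages are illustrative):

    import jax.numpy as jnp
    import rlax

    logits_t = jnp.zeros((3, 4))              # unnormalized action preferences
    a_t = jnp.array([0, 2, 1])                # actions sampled from those preferences
    adv_t = jnp.array([1.0, -0.5, 0.3])       # estimated advantages
    w_t = jnp.ones((3,))                      # per-timestep weights (e.g. padding mask)

    loss = rlax.policy_gradient_loss(logits_t, a_t, adv_t, w_t)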

QPG Loss

rlax.qpg_loss(logits_t, q_t, use_stop_gradient=True)[source]

Computes the QPG (Q-based Policy Gradient) loss.

See “Actor-Critic Policy Optimization in Partially Observable Multiagent Environments” by Srinivasan, Lanctot (https://arxiv.org/abs/1810.09026).

Parameters
  • logits_t (Array) – a sequence of unnormalized action preferences.

  • q_t (Array) – the observed or estimated action value from executing actions a_t at time t.

  • use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to advantages.

Return type

Array

Returns

QPG Loss.

RM Loss

rlax.rm_loss(logits_t, q_t, use_stop_gradient=True)[source]

Computes the RMPG (Regret Matching Policy Gradient) loss.

The gradient of this loss adapts the Regret Matching rule by weighting the standard PG update with thresholded regret.

See “Actor-Critic Policy Optimization in Partially Observable Multiagent Environments” by Srinivasan, Lanctot (https://arxiv.org/abs/1810.09026).

Parameters
  • logits_t (Array) – a sequence of unnormalized action preferences.

  • q_t (Array) – the observed or estimated action value from executing actions a_t at time t.

  • use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to advantages.

Return type

Array

Returns

RM Loss.

RPG Loss

rlax.rpg_loss(logits_t, q_t, use_stop_gradient=True)[source]

Computes the RPG (Regret Policy Gradient) loss.

The gradient of this loss adapts the Regret Matching rule by weighting the standard PG update with regret.

See “Actor-Critic Policy Optimization in Partially Observable Multiagent Environments” by Srinivasan, Lanctot (https://arxiv.org/abs/1810.09026).

Parameters
  • logits_t (Array) – a sequence of unnormalized action preferences.

  • q_t (Array) – the observed or estimated action value from executing actions a_t at time t.

  • use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to advantages.

Return type

Array

Returns

RPG Loss.

VMPO Compute Weights and Temperature Loss

rlax.vmpo_compute_weights_and_temperature_loss(advantages, restarting_weights, importance_weights, temperature_constraint, projection_operator, top_k_fraction, axis_name=None, use_stop_gradient=True)[source]

Computes the weights and temperature loss for V-MPO.

Parameters
  • advantages (Array) – Advantages for the E-step. Shape E*.

  • restarting_weights (Array) – Restarting weights, 0 means that this step is the start of a new episode and we ignore losses at this step because the agent cannot influence these. Shape E*.

  • importance_weights (Array) – Optional importance weights. Shape E*

  • temperature_constraint (LagrangePenalty) – Lagrange constraint for the E-step temperature optimization.

  • projection_operator (Callable[[Numeric], Numeric]) – Function to project dual variables (temperature and kl constraint alphas) into the positive range.

  • top_k_fraction (float) – Fraction of samples to use in the E-step.

  • axis_name (Optional[str]) – Optional axis name for pmap or ‘vmap’. If None, computations are performed locally on each device.

  • use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient.

Return type

Tuple[Scalar, Array, Scalar]

Returns

The temperature loss, normalized weights and number of samples used.

VMPO Loss

rlax.vmpo_loss(sample_log_probs, advantages, temperature_constraint, kl_constraints, projection_operator=functools.partial(<CompiledFunction of <function clip>>, a_min=1e-10), restarting_weights=None, importance_weights=None, top_k_fraction=0.5, policy_loss_weight=1.0, temperature_loss_weight=1.0, kl_loss_weight=1.0, alpha_loss_weight=1.0, axis_name=None, use_stop_gradient=True)[source]

Calculates the V-MPO policy improvement loss.

Note: This is a per-example loss which works on any shape inputs as long as they are consistent. We denote the shape of the examples E* for ease of reference.

Parameters
  • sample_log_probs (Array) – Log probabilities of actions for each example. Shape E*.

  • advantages (Array) – Advantages for the E-step. Shape E*.

  • temperature_constraint (LagrangePenalty) – Lagrange constraint for the E-step temperature optimization.

  • kl_constraints (Sequence[Tuple[Array, LagrangePenalty]]) – KL and variables for applying Lagrangian penalties to bound them in the M-step, KLs are E* or [E*, A]. Here A is the action dimension in the case of per-dimension KL constraints.

  • projection_operator (Callable[[Numeric], Numeric]) – Function to project dual variables (temperature and kl constraint alphas) into the positive range.

  • restarting_weights (Optional[Array]) – Optional restarting weights, shape E*, 0 means that this step is the start of a new episode and we ignore losses at this step because the agent cannot influence these.

  • importance_weights (Optional[Array]) – Optional importance weights, shape E*.

  • top_k_fraction (float) – Fraction of samples to use in the E-step.

  • policy_loss_weight (float) – Weight for the policy loss.

  • temperature_loss_weight (float) – Weight for the temperature loss.

  • kl_loss_weight (float) – Weight for the KL loss.

  • alpha_loss_weight (float) – Weight for the alpha loss.

  • axis_name (Optional[str]) – Optional axis name for pmap. If None, computations are performed locally on each device.

  • use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient.

Return type

Tuple[Array, MpoOutputs]

Returns

Per example loss with same shape E* as array inputs, and additional data including the components of this loss and the normalized weights in the AdditionalOutputs.

Exploration

add_dirichlet_noise(key, prior, …)

Returns discrete actions with noise drawn from a Dirichlet distribution.

add_gaussian_noise(key, action, stddev)

Returns continuous action with noise drawn from a Gaussian distribution.

add_ornstein_uhlenbeck_noise(key, action, …)

Returns continuous action with noise from Ornstein-Uhlenbeck process.

episodic_memory_intrinsic_rewards(…[, …])

Compute intrinsic rewards for exploration via episodic memory.

Add Dirichlet Noise

rlax.add_dirichlet_noise(key, prior, dirichlet_alpha, dirichlet_fraction)[source]

Returns discrete actions with noise drawn from a Dirichlet distribution.

See “Mastering the Game of Go without Human Knowledge” by Silver et al. 2017 (https://discovery.ucl.ac.uk/id/eprint/10045895/1/agz_unformatted_nature.pdf), “A General Reinforcement Learning Algorithm that Masters Chess, Shogi and Go Through Self-Play” by Silver et al. 2018 (http://airesearch.com/wp-content/uploads/2016/01/deepmind-mastering-go.pdf), and “Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model” by Schrittwieser et al. 2019 (https://arxiv.org/abs/1911.08265).

The AlphaZero family of algorithms adds noise sampled from a symmetric Dirichlet distribution to the prior policy generated by MCTS. Because the agent then samples from this new, noisy prior over actions, this encourages better exploration of the root node’s children.

Specifically, this computes:

noise ~ Dirichlet(alpha)
noisy_prior = (1 - fraction) * prior + fraction * noise

Note that alpha is a single float to draw from a symmetric Dirichlet.

For reference values, AlphaZero uses 0.3, 0.15, 0.03 for Chess, Shogi, and Go respectively, and MuZero uses 0.25 for Atari.

Parameters
  • key (Array) – a key from jax.random.

  • prior (Array) – 2-dim continuous prior policy vector of shape [B, N], for B batch size and N num_actions.

  • dirichlet_alpha (float) – concentration parameter to parametrize Dirichlet distribution.

  • dirichlet_fraction (float) – float from 0 to 1 interpolating between using only the prior policy or just the noise.

Return type

Array

Returns

noisy action, of the same shape as input action.
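
A minimal sketch for a batch of two prior policies over four actions (the priors below are illustrative; alpha 0.3 and fraction 0.25 follow the reference values quoted above):

    import jax
    import jax.numpy as jnp
    import rlax

    key = jax.random.PRNGKey(0)
    prior = jnp.array([[0.4, 0.3, 0.2, 0.1],
                       [0.25, 0.25, 0.25, 0.25]])   # shape [B=2, N=4]

    noisy_prior = rlax.add_dirichlet_noise(
        key, prior, dirichlet_alpha=0.3, dirichlet_fraction=0.25)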

Add Gaussian Noise

rlax.add_gaussian_noise(key, action, stddev)[source]

Returns continuous action with noise drawn from a Gaussian distribution.

Parameters
  • key (Array) – a key from jax.random.

  • action (Array) – continuous action scalar or vector.

  • stddev (float) – standard deviation of noise distribution.

Return type

Array

Returns

noisy action, of the same shape as input action.
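
A minimal usage sketch, with a placeholder continuous action:

  import jax
  import jax.numpy as jnp
  import rlax

  key = jax.random.PRNGKey(0)
  action = jnp.array([0.2, -0.7])   # placeholder policy output
  noisy_action = rlax.add_gaussian_noise(key, action, stddev=0.1)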

Add Ornstein Uhlenbeck Noise

rlax.add_ornstein_uhlenbeck_noise(key, action, noise_tm1, damping, stddev)[source]

Returns continuous action with noise from Ornstein-Uhlenbeck process.

See “On the theory of Brownian Motion” by Uhlenbeck and Ornstein. (https://journals.aps.org/pr/abstract/10.1103/PhysRev.36.823).

Parameters
  • key (Array) – a key from jax.random.

  • action (Array) – continuous action scalar or vector.

  • noise_tm1 (Array) – noise sampled from OU process in previous timestep.

  • damping (float) – parameter for controlling autocorrelation of OU process.

  • stddev (float) – standard deviation of noise distribution.

Return type

Array

Returns

noisy action, of the same shape as input action.
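
Because the OU process is autocorrelated, the sampled noise must be carried across timesteps via noise_tm1. A minimal sketch; the zero actions are placeholders, and recovering the noise state by subtracting the clean action back out is one convenient way to thread it:

  import jax
  import jax.numpy as jnp
  import rlax

  key = jax.random.PRNGKey(0)
  noise = jnp.zeros(2)                    # OU noise state for a 2-dim action

  for t in range(3):
      key, subkey = jax.random.split(key)
      action = jnp.zeros(2)               # placeholder policy output at step t
      noisy_action = rlax.add_ornstein_uhlenbeck_noise(
          subkey, action, noise, damping=0.15, stddev=0.2)
      noise = noisy_action - action       # carry the correlated noise to t+1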

Episodic Memory Intrinsic Rewards

rlax.episodic_memory_intrinsic_rewards(embeddings, num_neighbors, reward_scale, intrinsic_reward_state=None, constant=0.001, epsilon=0.0001, cluster_distance=0.008, max_similarity=8.0, max_memory_size=30000)[source]

Compute intrinsic rewards for exploration via episodic memory.

This method is adopted from the intrinsic reward computation used in “Never Give Up: Learning Directed Exploration Strategies” by Puigdomènech Badia et al. (2020) (https://arxiv.org/abs/2002.06038) and “Agent57: Outperforming the Atari Human Benchmark” by Puigdomènech Badia et al. (2020) (https://arxiv.org/abs/2003.13350).

From an embedding, we compute the intra-episode intrinsic reward with respect to a pre-existing set of embeddings.

NOTE: For this function to be jittable, static_argnums=[1,] must be passed, as the internal jax.lax.top_k(neg_distances, num_neighbors) computation in knn_query cannot be jitted with a dynamic num_neighbors that is passed as an argument.

Parameters
  • embeddings (Array) – Array, shaped [M, D] for number of new state embeddings M and feature dim D.

  • num_neighbors (int) – int; number of neighbors K used in the kNN query.

  • reward_scale (float) – The β term used in the Agent57 paper to scale the reward.

  • intrinsic_reward_state (Optional[IntrinsicRewardState]) –

    An IntrinsicRewardState namedtuple, containing:

    • memory: Array; an array of previous memories within an episode, padded with zeros up to max_memory_size.

    • next_memory_index: the index in the static memory array at which to add the next embeddings; updated in a ring-buffer fashion.

    • distance_sum: Scalar, Optional; running sum of the negative squared distances computed by consecutive kNN queries, used to compute the mean distance.

    • distance_count: Scalar, Optional; running count of the negative squared distances computed by consecutive kNN queries, used to compute the mean distance.

    NOTE: only on the first call to episodic_memory_intrinsic_rewards may intrinsic_reward_state be omitted; if None is given, an IntrinsicRewardState is initialized with default parameters: the memory is initialized to an array of jnp.inf of shape [max_memory_size, D] (for feature dim D), and next_memory_index, distance_sum, and distance_count default to 0.

  • constant (float) – float; small constant used for numerical stability when normalizing distances.

  • epsilon (float) – float; small constant used for numerical stability when computing kernel output.

  • cluster_distance (float) – float; the ξ term used in the Agent57 paper to bound the distance rate used in the kernel computation.

  • max_similarity (float) – float; max limit of similarity; used to zero rewards when similarity between memories is too high to be considered ‘useful’ for an agent.

  • max_memory_size (int) – int; the maximum number of memories to store. Note that performance will be marginally faster if max_memory_size is an exact multiple of M (the number of embeddings added to memory per call to episodic_memory_intrinsic_rewards).

Returns

A tuple of:

  • reward: Array, shaped [M, 1]; the intrinsic reward for each embedding, computed from its similarity to the stored memories.

  • intrinsic_reward_state: an IntrinsicRewardState namedtuple, containing:

    • memory: Array; an array of previous memories within an episode, padded with zeros up to max_memory_size.

    • next_memory_index: the index in the static memory array at which to add the next embeddings; updated in a ring-buffer fashion.

    • distance_sum: Scalar, Optional; running sum of the negative squared distances computed by consecutive kNN queries, used to compute the mean distance.

    • distance_count: Scalar, Optional; running count of the negative squared distances computed by consecutive kNN queries, used to compute the mean distance.
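
As the note above explains, num_neighbors must be treated as a static argument when jitting. A minimal sketch, with illustrative shapes and hyperparameters, and assuming the function returns the (reward, intrinsic_reward_state) pair described above:

  import jax
  import jax.numpy as jnp
  import rlax

  # num_neighbors (argument 1) must be static because of the internal top_k.
  compute_rewards = jax.jit(
      rlax.episodic_memory_intrinsic_rewards, static_argnums=[1])

  embeddings = jnp.ones((4, 32))   # M=4 new embeddings, feature dim D=32
  rewards, reward_state = compute_rewards(
      embeddings, 10, 0.3, intrinsic_reward_state=None)
  # Pass reward_state back in on the next call to keep accumulating memories.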

KNN Query

Utilities

AllSum([axis_name])

Helper for summing over elements in an array and over devices.

batched_index(values, indices[, keepdims])

Index into the last dimension of a tensor, preserving all others dims.

clip_gradient

Identity in the forward pass; clips the gradient in the backward pass.

lhs_broadcast(source, target)

Ensures that source is compatible with target for broadcasting.

one_hot(indices, num_classes[, dtype])

Returns a one-hot version of indices.

embed_oar(features, action, reward, num_actions)

Embed each of the (observation, action, reward) inputs & concatenate.

tree_map_zipped(fn, nests)

Map a function over a list of identical nested structures.

tree_select(pred, on_true, on_false)

Select either one of two identical nested structs based on condition.

tree_split_key(rng_key, tree_like)

Generate random keys for each leaf in a tree.

tree_split_leaves(tree_like[, axis, keepdim])

Splits a tree of arrays into an array of trees avoiding data copying.

conditional_update(new_tensors, old_tensors, …)

Checks whether to update the params and returns the correct params.

incremental_update(new_tensors, old_tensors, tau)

Incrementally update all elements from a nested struct.

periodic_update(new_tensors, old_tensors, …)

Periodically switch all elements from a nested struct with new elements.

All Sum

class rlax.AllSum(axis_name=None)[source]

Helper for summing over elements in an array and over devices.

__init__(axis_name=None)[source]

Sums locally and then over devices with the axis name provided.

__call__(x, axis=None)[source]

Call self as a function.

Return type

Numeric

Batched Index

rlax.batched_index(values, indices, keepdims=False)[source]

Index into the last dimension of a tensor, preserving all others dims.

Parameters
  • values (Array) – a tensor of shape […, D],

  • indices (Array) – indices of shape […].

  • keepdims (bool) – whether to keep the final dimension.

Return type

Array

Returns

a tensor of shape […] or […, 1].
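
A small example, e.g. selecting the Q-value of the action that was taken:

  import jax.numpy as jnp
  import rlax

  q_values = jnp.array([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])  # shape [B, D]
  actions = jnp.array([2, 0])              # shape [B]

  chosen_q = rlax.batched_index(q_values, actions)
  # chosen_q == [3.0, 4.0]; with keepdims=True the shape would be [B, 1].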

Clip Gradient

rlax.clip_gradient(*args, **kwargs)[source]

Identity in the forward pass; clips the gradient to a bounded range in the backward pass (cf. the note under huber_loss below).

LHS Broadcast

rlax.lhs_broadcast(source, target)[source]

Ensures that source is compatible with target for broadcasting.

One Hot

rlax.one_hot(indices, num_classes, dtype=jnp.float32)[source]

Returns a one-hot version of indices.

Parameters
  • indices – A tensor of indices.

  • num_classes – Number of classes in the one-hot dimension.

  • dtype – The dtype.

Returns

The one-hot tensor. If indices’ shape is [A, B, …], the output shape is [A, B, …, num_classes].
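
For example:

  import jax.numpy as jnp
  import rlax

  rlax.one_hot(jnp.array([0, 2, 1]), num_classes=3)
  # [[1., 0., 0.],
  #  [0., 0., 1.],
  #  [0., 1., 0.]]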

Embed OAR

rlax.embed_oar(features, action, reward, num_actions)[source]

Embed each of the (observation, action, reward) inputs & concatenate.

Return type

Array

Tree Map Zipped

rlax.tree_map_zipped(fn, nests)[source]

Map a function over a list of identical nested structures.

Parameters
  • fn (Callable[..., Any]) – the function to map; must have arity equal to len(nests).

  • nests (Sequence[Any]) – a list of identical nested structures.

Returns

a nested structure whose leaves are outputs of applying fn.
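
For example, summing the leaves of two identically structured nests (the fn’s arity matches the number of nests):

  import jax.numpy as jnp
  import rlax

  nest_a = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
  nest_b = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array(1.5)}

  summed = rlax.tree_map_zipped(lambda x, y: x + y, [nest_a, nest_b])
  # {'w': [4., 6.], 'b': 2.0}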

Tree Select

rlax.tree_select(pred, on_true, on_false)[source]

Select either one of two identical nested structs based on condition.

Parameters
  • pred (Array) – a boolean condition.

  • on_true (Any) – an arbitrary nested structure.

  • on_false (Any) – a nested structure identical to on_true.

Returns

the selected nested structure.

Tree Split Key

rlax.tree_split_key(rng_key, tree_like)[source]

Generate random keys for each leaf in a tree.

Parameters
  • rng_key (Array) – a JAX pseudo random number generator key.

  • tree_like (Any) – a nested structure.

Returns

a new key, and a tree of keys with the same structure as tree_like.

Tree Split Leaves

rlax.tree_split_leaves(tree_like, axis=0, keepdim=False)[source]

Splits a tree of arrays into an array of trees avoiding data copying.

Note: jax.numpy.DeviceArray’s data gets copied.

Parameters
  • tree_like (Any) – a nested object with leaves to split.

  • axis (int) – an axis for splitting.

  • keepdim (bool) – a bool indicating whether to keep axis dimension.

Returns

A tuple of size(axis) trees containing results of splitting.

Conditional Update

rlax.conditional_update(new_tensors, old_tensors, is_time)[source]

Checks whether to update the params and returns the correct params.

Incremental Update

rlax.incremental_update(new_tensors, old_tensors, tau)[source]

Incrementally update all elements from a nested struct.

Periodic Update

rlax.periodic_update(new_tensors, old_tensors, steps, update_period)[source]

Periodically switch all elements from a nested struct with new elements.
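
These two utilities cover the usual target-network update schemes (soft/Polyak averaging and periodic hard copies). A minimal sketch with placeholder parameter trees, assuming incremental_update computes tau * new + (1 - tau) * old, the standard Polyak rule:

  import jax.numpy as jnp
  import rlax

  online_params = {'w': jnp.array([1.0, 2.0])}
  target_params = {'w': jnp.array([0.0, 0.0])}

  # Soft (Polyak) update: target <- tau * online + (1 - tau) * target.
  target_params = rlax.incremental_update(online_params, target_params, tau=0.01)

  # Hard update: copy the online params only when steps hits a multiple of
  # update_period, otherwise keep the old target params.
  target_params = rlax.periodic_update(
      online_params, target_params, steps=jnp.asarray(1000), update_period=100)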

General Value Functions

pixel_control_rewards(observations, cell_size)

Calculates cumulants for pixel control tasks from an observation sequence.

feature_control_rewards(features[, …])

Calculates cumulants for feature control tasks from a sequence of features.

Pixel Control Rewards

rlax.pixel_control_rewards(observations, cell_size)[source]

Calculates cumulants for pixel control tasks from an observation sequence.

The observations are first split into a grid of cells, each of size cell_size × cell_size pixels. For each cell a distinct pseudo reward is computed as the average absolute change in pixel intensity across all pixels in the cell. The change in intensity is averaged across both pixels and channels (e.g. RGB).

The observations provided to this function should be cropped suitably, to ensure that the observations’ height and width are a multiple of cell_size. The values of the observations tensor should be rescaled to [0, 1].

See “Reinforcement Learning with Unsupervised Auxiliary Tasks” by Jaderberg, Mnih, Czarnecki et al. (https://arxiv.org/abs/1611.05397).

Parameters
  • observations (Array) – A tensor of shape [T+1, H, W, C], where T is the sequence length, H is the height, W is the width, and C is the channel dimension.

  • cell_size (int) – The size of each cell.

Return type

Array

Returns

A tensor of pixel control rewards calculated from the observation. The shape is [T,H’,W’], where H’=H/cell_size and W’=W/cell_size.
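
A minimal shape-level sketch; the random observations stand in for a real frame sequence rescaled to [0, 1]:

  import jax
  import jax.numpy as jnp
  import rlax

  T, H, W, C = 5, 16, 16, 3
  observations = jax.random.uniform(jax.random.PRNGKey(0), (T + 1, H, W, C))

  cumulants = rlax.pixel_control_rewards(observations, cell_size=4)
  # cumulants.shape == (T, H // 4, W // 4) == (5, 4, 4)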

Feature Control Rewards

rlax.feature_control_rewards(features, cumulant_type='absolute_change', discount=None)[source]

Calculates cumulants for feature control tasks from a sequence of features.

For each feature dimension, a distinct pseudo reward is computed based on the change in the feature value between consecutive timesteps. Depending on cumulant_type, cumulants may equal the features themselves, the absolute difference between their values in consecutive steps, their increase/decrease, or a potential-based reward discounted by discount.

See “Reinforcement Learning with Unsupervised Auxiliary Tasks” by Jaderberg, Mnih, Czarnecki et al. (https://arxiv.org/abs/1611.05397).

Parameters
  • features (Array) – A tensor of shape [T+1,D] of features.

  • cumulant_type – either ‘feature’ (the feature itself is the reward), ‘absolute_change’ (the reward equals the absolute difference between consecutive timesteps), ‘increase’ (the reward equals the increase in the value of the feature), ‘decrease’ (the reward equals the decrease in the value of the feature), or ‘potential’ (r = gamma * phi_{t+1} - phi_t).

  • discount – (optional) discount for potential based rewards.

Return type

Array

Returns

A tensor of cumulants calculated from the features. The shape is [T,D].
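
For example, the same feature sequence can yield different cumulants depending on cumulant_type; the feature values below are placeholders:

  import jax.numpy as jnp
  import rlax

  features = jnp.array([[0.0, 1.0],
                        [0.5, 1.0],
                        [1.5, 0.0]])   # shape [T+1, D] with T=2, D=2

  abs_change = rlax.feature_control_rewards(features)  # default ‘absolute_change’
  potential = rlax.feature_control_rewards(
      features, cumulant_type='potential', discount=0.99)
  # Both results have shape [T, D] == (2, 2).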

Pop Art

art(state, targets, indices, step_size, …)

Adaptively rescale targets.

normalize(state, unnormalized, indices)

Returns normalized values.

pop(params, old, new)

Preserves outputs precisely.

popart(num_outputs, step_size, scale_lb, …)

Generates functions giving initial PopArt state and update rule.

PopArtState(shift, scale, second_moment)

unnormalize(state, normalized, indices)

Returns unnormalized values.

unnormalize_linear(state, inputs, indices)

Selects and unnormalizes output of a Linear.

Art

rlax.art(state, targets, indices, step_size, scale_lb, scale_ub, axis_name=None)[source]

Adaptively rescale targets.

Parameters
  • state (PopArtState) – The PopArt summary stats.

  • targets (Array) – targets which are rescaled.

  • indices (Array) – Which indices of the state to use.

  • step_size (float) – The step size for learning the scale & shift parameters.

  • scale_lb (float) – Lower bound for the scale.

  • scale_ub (float) – Upper bound for the scale.

  • axis_name – The name of the axis to aggregate over (if a str), or an iterable of axis names to aggregate over multiple axes. Defaults to None, i.e. no aggregation.

Return type

PopArtState

Returns

New popart state which can be used to rescale targets.

Normalize

rlax.normalize(state, unnormalized, indices)[source]

Returns normalized values.

Parameters
  • state (PopArtState) – The PopArt summary stats.

  • unnormalized (Array) – unnormalized values that we applied PopArt to.

  • indices (Array) – Which scales and shifts to use.

Return type

Array

Returns

Normalized PopArt values.

Pop

rlax.pop(params, old, new)[source]

Preserves outputs precisely.

Parameters
  • params (LinearParams) – The parameters of the linear to preserve.

  • old (PopArtState) – The old PopArt state.

  • new (PopArtState) – The new PopArt state.

Returns

new parameters.

PopArt

rlax.popart(num_outputs, step_size, scale_lb, scale_ub, axis_name=None)[source]

Generates functions giving initial PopArt state and update rule.

Parameters
  • num_outputs (int) – The number of outputs generated by the linear we’re preserving.

  • step_size (float) – The step size for learning the scale & shift parameters.

  • scale_lb (float) – Lower bound for the scale.

  • scale_ub (float) – Upper bound for the scale.

  • axis_name – The name of the axis to aggregate over (if a str), or an iterable of axis names to aggregate over multiple axes. Defaults to None, i.e. no aggregation.

Returns

  • initial_state: A function returning the initial PopArt state.

  • popart_update: A function updating the PopArt state and the parameters of the preceding linear.

Return type

A tuple of two functions (initial_state, popart_update)

PopArtState

class rlax.PopArtState(shift, scale, second_moment)
__getnewargs__()[source]

Return self as a plain tuple. Used by copy and pickle.

static __new__(_cls, shift, scale, second_moment)

Create new instance of PopArtState(shift, scale, second_moment)

property scale

Alias for field number 1

property second_moment

Alias for field number 2

property shift

Alias for field number 0

Unnormalize

rlax.unnormalize(state, normalized, indices)[source]

Returns unnormalized values.

Parameters
  • state (PopArtState) – The PopArt summary stats.

  • normalized (Array) – normalized values that we apply PopArt to.

  • indices (Array) – Which scales and shifts to use.

Return type

Array

Returns

Unnormalized PopArt values.
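
The pieces above can also be combined by hand. A minimal sketch that builds a PopArtState directly, updates its statistics with art, and round-trips some targets through normalize/unnormalize; the step size, bounds, and target values are illustrative only:

  import jax.numpy as jnp
  import rlax

  num_outputs = 2
  state = rlax.PopArtState(
      shift=jnp.zeros(num_outputs),
      scale=jnp.ones(num_outputs),
      second_moment=jnp.ones(num_outputs))

  targets = jnp.array([10.0, -3.0, 7.5])
  indices = jnp.array([0, 1, 0])   # which output statistics each target uses

  state = rlax.art(state, targets, indices,
                   step_size=1e-3, scale_lb=1e-6, scale_ub=1e6)
  normed = rlax.normalize(state, targets, indices)
  recovered = rlax.unnormalize(state, normed, indices)   # ~= targets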

Unnormalize Linear

rlax.unnormalize_linear(state, inputs, indices)[source]

Selects and unnormalizes output of a Linear.

Parameters
  • state (PopArtState) – The PopArt summary stats.

  • inputs (Array) – The (normalized) output of the Linear that we apply PopArt to.

  • indices (Array) – Which indices of inputs to use.

Return type

PopArtOutput

Returns

PopArtOutput, a tuple of the normalized and unnormalized PopArt values.

Transforms

HYPERBOLIC_SIN_PAIR

TxPair(apply, apply_inv)

identity(x)

Identity transform.

IDENTITY_PAIR

TxPair(apply, apply_inv)

logit(x)

Logit transform, inverse of sigmoid.

power(x, p)

Power transform; power_tx(_, 1/p) is the inverse of power_tx(_, p).

sigmoid(x)

Sigmoid transform.

signed_expm1(x)

Signed exponential of x - 1, inverse of signed_logp1.

signed_hyperbolic(x[, eps])

Signed hyperbolic transform, inverse of signed_parabolic.

SIGNED_HYPERBOLIC_PAIR

TxPair(apply, apply_inv)

signed_logp1(x)

Signed logarithm of x + 1.

SIGNED_LOGP1_PAIR

TxPair(apply, apply_inv)

signed_parabolic(x[, eps])

Signed parabolic transform, inverse of signed_hyperbolic.

transform_from_2hot(probs, min_value, …)

Transforms from a categorical distribution to a scalar.

transform_to_2hot(scalar, min_value, …)

Transforms a scalar tensor to a 2 hot representation.

Identity

rlax.identity(x)[source]

Identity transform.

Return type

Array

Logit

rlax.logit(x)[source]

Logit transform, inverse of sigmoid.

Return type

Array

Power

rlax.power(x, p)[source]

Power transform; power_tx(_, 1/p) is the inverse of power_tx(_, p).

Return type

Array

Sigmoid

rlax.sigmoid(x)[source]

Sigmoid transform.

Return type

Array

Signed Exponential

rlax.signed_expm1(x)[source]

Signed exponential of x - 1, inverse of signed_logp1.

Return type

Array

Signed Hyperbolic

rlax.signed_hyperbolic(x, eps=0.001)[source]

Signed hyperbolic transform, inverse of signed_parabolic.

Return type

Array

Signed Logarithm

rlax.signed_logp1(x)[source]

Signed logarithm of x + 1.

Return type

Array

Signed Parabolic

rlax.signed_parabolic(x, eps=0.001)[source]

Signed parabolic transform, inverse of signed_hyperbolic.

Return type

Array
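
The signed_hyperbolic/signed_parabolic functions are inverses of each other, and are also bundled as SIGNED_HYPERBOLIC_PAIR, a TxPair with apply/apply_inv fields as listed in the table above. A short roundtrip sketch under those assumptions:

  import jax.numpy as jnp
  import rlax

  x = jnp.array([-10.0, 0.0, 10.0])
  y = rlax.signed_hyperbolic(x)        # compress large magnitudes
  x_back = rlax.signed_parabolic(y)    # invert: x_back ~= x

  pair = rlax.SIGNED_HYPERBOLIC_PAIR   # the same transforms, bundled as a TxPair
  z = pair.apply_inv(pair.apply(x))    # z ~= x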

Transform from 2 Hot

rlax.transform_from_2hot(probs, min_value, max_value, num_bins)[source]

Transforms from a categorical distribution to a scalar.

Return type

Array

Transform to 2 Hot

rlax.transform_to_2hot(scalar, min_value, max_value, num_bins)[source]

Transforms a scalar tensor to a 2 hot representation.

Return type

Array
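
The two functions invert each other up to the resolution of the binning. A small roundtrip sketch with an illustrative support of 51 bins on [-5, 5]:

  import jax.numpy as jnp
  import rlax

  scalar = jnp.array([0.3, -1.7])
  probs = rlax.transform_to_2hot(
      scalar, min_value=-5.0, max_value=5.0, num_bins=51)
  # probs has shape [..., 51]; at most two adjacent bins are non-zero per scalar.

  recovered = rlax.transform_from_2hot(
      probs, min_value=-5.0, max_value=5.0, num_bins=51)
  # recovered ~= scalar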

Losses

l2_loss(predictions[, targets])

Calculates the L2 loss of predictions wrt targets.

likelihood(predictions, targets)

Calculates the likelihood of predictions wrt targets.

log_loss(predictions, targets)

Calculates the log loss of predictions wrt targets.

huber_loss(x[, delta])

Huber loss, similar to L2 loss close to zero, L1 loss away from zero.

pixel_control_loss(observations, actions, …)

Calculate n-step Q-learning loss for pixel control auxiliary task.

L2 Loss

rlax.l2_loss(predictions, targets=None)[source]

Calculates the L2 loss of predictions wrt targets.

If targets are not provided this function acts as an L2-regularizer for predictions.

Note: the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not in “The Elements of Statistical Learning” by Hastie, Tibshirani and Friedman.

Parameters
  • predictions (Array) – a vector of arbitrary shape.

  • targets (Optional[Array]) – a vector of shape compatible with predictions.

Return type

Array

Returns

a vector of the same shape as predictions.

Likelihood

rlax.likelihood(predictions, targets)[source]

Calculates the likelihood of predictions wrt targets.

Parameters
  • predictions (Array) – a vector of arbitrary shape.

  • targets (Array) – a vector of shape compatible with predictions.

Return type

Array

Returns

a vector of the same shape as predictions.

Log Loss

rlax.log_loss(predictions, targets)[source]

Calculates the log loss of predictions wrt targets.

Parameters
  • predictions (Array) – a vector of probabilities of arbitrary shape.

  • targets (Array) – a vector of probabilities of shape compatible with predictions.

Return type

Array

Returns

a vector of the same shape as predictions.

Huber Loss

rlax.huber_loss(x, delta=1.0)[source]

Huber loss, similar to L2 loss close to zero, L1 loss away from zero.

See “Robust Estimation of a Location Parameter” by Huber. (https://projecteuclid.org/download/pdf_1/euclid.aoms/1177703732).

Parameters
  • x (Array) – a vector of arbitrary shape.

  • delta (float) – the bounds for the Huber loss transformation; defaults to 1.

Note grad(huber_loss(x)) is equivalent to grad(0.5 * clip_gradient(x)**2).

Return type

Array

Returns

a vector of the same shape as x.

Pixel Control Loss

rlax.pixel_control_loss(observations, actions, action_values, discount_factor, cell_size)[source]

Calculate n-step Q-learning loss for pixel control auxiliary task.

For each pixel-based pseudo reward signal, the corresponding action-value function is trained off-policy, using Q(lambda). A discount of 0.9 is commonly used for learning the value functions.

Note that, since pseudo rewards have a spatial structure, with neighbouring cells exhibiting strong correlations, it is convenient to predict the action values for all the cells through a deconvolutional head.

See “Reinforcement Learning with Unsupervised Auxiliary Tasks” by Jaderberg, Mnih, Czarnecki et al. (https://arxiv.org/abs/1611.05397).

Parameters
  • observations (Array) – A tensor of shape [T+1, …], where … is the observation shape and T the sequence length.

  • actions (Array) – A tensor, shape [T,], of the actions across each sequence.

  • action_values (Array) – A tensor, shape [T+1, H, W, N] of pixel control action values, where H, W are the number of pixel control cells/tasks, and N is the number of actions.

  • discount_factor (Union[Array, Scalar]) – discount used for learning the value function associated to the pseudo rewards; must be a scalar or a Tensor of shape [T].

  • cell_size (int) – size of the cells used to derive the pixel based pseudo-rewards.

Returns

a tensor containing the spatial loss, shape [T, H, W].

Raises

ValueError – if the shape of action_values is not compatible with that of the pseudo-rewards derived from the observations.
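
A minimal shape-level sketch; the random tensors are placeholders for a real trajectory, with H' = H / cell_size and W' = W / cell_size as in pixel_control_rewards:

  import jax
  import jax.numpy as jnp
  import rlax

  T, H, W, C, num_actions, cell_size = 5, 16, 16, 3, 6, 4
  key, k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 4)

  observations = jax.random.uniform(k1, (T + 1, H, W, C))    # rescaled to [0, 1]
  actions = jax.random.randint(k2, (T,), 0, num_actions)
  action_values = jax.random.normal(
      k3, (T + 1, H // cell_size, W // cell_size, num_actions))

  spatial_loss = rlax.pixel_control_loss(
      observations, actions, action_values,
      discount_factor=0.9, cell_size=cell_size)
  # spatial_loss.shape == (T, H // cell_size, W // cell_size)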

Distributions

categorical_cross_entropy(labels, logits)

Computes the softmax cross entropy between sets of logits and labels.

categorical_importance_sampling_ratios(…)

Compute importance sampling ratios from logits.

categorical_kl_divergence(p_logits, q_logits)

Compute the KL between two categorical distributions from their logits.

categorical_sample(key, probs)

Sample from a set of discrete probabilities.

clipped_entropy_softmax([temperature, …])

A softmax distribution with clipped entropy (an entropy_clip of 1 is equivalent to no clipping).

epsilon_greedy([epsilon])

An epsilon-greedy distribution.

epsilon_softmax(epsilon, temperature)

An epsilon-softmax distribution.

gaussian_diagonal([sigma])

A gaussian distribution with diagonal covariance matrix.

greedy()

A greedy distribution.

multivariate_normal_kl_divergence(mu_0, …)

Compute the KL divergence between two Gaussian distributions with diagonal covariance matrices.

safe_epsilon_softmax(epsilon, temperature)

Tolerantly handles the temperature=0 case.

softmax([temperature])

A softmax distribution.

squashed_gaussian([sigma_min, sigma_max])

A squashed gaussian distribution with diagonal covariance matrix.

Categorical Cross Entropy

rlax.categorical_cross_entropy(labels, logits)[source]

Computes the softmax cross entropy between sets of logits and labels.

See “Deep Learning” by Goodfellow et al. (http://www.deeplearningbook.org/contents/prob.html).

Parameters
  • labels (Array) – a valid probability distribution (non-negative, sum to 1).

  • logits (Array) – unnormalized log probabilities.

Return type

Array

Returns

a scalar loss.

Categorical Importance Sampling Ratios

rlax.categorical_importance_sampling_ratios(pi_logits_t, mu_logits_t, a_t)[source]

Compute importance sampling ratios from logits.

Parameters
  • pi_logits_t (Array) – unnormalized logits at time t for the target policy.

  • mu_logits_t (Array) – unnormalized logits at time t for the behavior policy.

  • a_t (Array) – actions at time t.

Return type

Array

Returns

importance sampling ratios.

Categorical KL Divergence

rlax.categorical_kl_divergence(p_logits, q_logits, temperature=1.0)[source]

Compute the KL between two categorical distributions from their logits.

Parameters
  • p_logits (Array) – unnormalized logits for the first distribution.

  • q_logits (Array) – unnormalized logits for the second distribution.

  • temperature (float) – the temperature for the softmax distribution; defaults to 1.

Return type

Array

Returns

the kl divergence between the distributions.

Categorical Sample

rlax.categorical_sample(key, probs)[source]

Sample from a set of discrete probabilities.

Clipped Entropy Softmax

rlax.clipped_entropy_softmax(temperature=1.0, entropy_clip=1.0)[source]

A softmax distribution with clipped entropy (an entropy_clip of 1 is equivalent to no clipping).

Epsilon Greedy

rlax.epsilon_greedy(epsilon=None)[source]

An epsilon-greedy distribution.
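
The constructors in this section return small containers of functions rather than samples directly. A minimal sketch for epsilon_greedy, assuming the usual rlax convention of sample(key, preferences) and probs(preferences) members (not spelled out in the summary above):

  import jax
  import jax.numpy as jnp
  import rlax

  key = jax.random.PRNGKey(0)
  q_values = jnp.array([1.0, 3.0, 2.0])   # per-action preferences

  dist = rlax.epsilon_greedy(epsilon=0.1)
  action = dist.sample(key, q_values)   # argmax with prob. 0.9 + 0.1/3, else uniform
  probs = dist.probs(q_values)          # ~[0.033, 0.933, 0.033]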

Epsilon Softmax

rlax.epsilon_softmax(epsilon, temperature)[source]

An epsilon-softmax distribution.

Gaussian Diagonal

rlax.gaussian_diagonal(sigma=None)[source]

A gaussian distribution with diagonal covariance matrix.

Greedy

rlax.greedy()[source]

A greedy distribution.

Multivariate Normal KL Divergence

rlax.multivariate_normal_kl_divergence(mu_0, sigma_0, mu_1, sigma_1)[source]

Compute the KL divergence between two Gaussian distributions with diagonal covariance matrices.

Parameters
  • mu_0 (Array) – array-like of mean values for policy 0.

  • sigma_0 (Numeric) – array-like of std values for policy 0.

  • mu_1 (Array) – array-like of mean values for policy 1.

  • sigma_1 (Numeric) – array-like of std values for policy 1.

Return type

Array

Returns

the kl divergence between the distributions.

Safe Epsilon Softmax

rlax.safe_epsilon_softmax(epsilon, temperature)[source]

Tolerantly handles the temperature=0 case.

Softmax

rlax.softmax(temperature=1.0)[source]

A softmax distribution.

Squashed Gaussian

rlax.squashed_gaussian(sigma_min=-4, sigma_max=0.0)[source]

A squashed gaussian distribution with diagonal covariance matrix.