Values, including both state and action-values;
Values for non-linear generalizations of the Bellman equations;
Return Distributions, aka distributional value functions;
General Value Functions, for cumulants other than the main reward;
Policies, via policy-gradients in both continuous and discrete action spaces.
Value Learning
categorical_double_q_learning | Implements double Q-learning for categorical Q distributions.
categorical_l2_project | Projects a categorical distribution (z_p, p) onto a different support z_q.
categorical_q_learning | Implements Q-learning for categorical Q distributions.
categorical_td_learning | Implements TD-learning for categorical value distributions.
discounted_returns | Calculates a discounted return from a trajectory.
double_q_learning | Calculates the double Q-learning temporal difference error.
expected_sarsa | Calculates the expected SARSA (SARSE) temporal difference error.
general_off_policy_returns_from_action_values | Calculates targets for various off-policy correction algorithms.
general_off_policy_returns_from_q_and_v | Calculates targets for various off-policy evaluation algorithms.
lambda_returns | Estimates a multistep truncated lambda return from a trajectory.
leaky_vtrace | Calculates Leaky V-Trace errors from importance weights.
leaky_vtrace_td_error_and_advantage | Calculates Leaky V-Trace errors and PG advantage from importance weights.
n_step_bootstrapped_returns | Computes strided n-step bootstrapped return targets over a sequence.
persistent_q_learning | Calculates the persistent Q-learning temporal difference error.
q_lambda | Calculates Peng’s or Watkins’ Q(lambda) temporal difference error.
q_learning | Calculates the Q-learning temporal difference error.
quantile_expected_sarsa | Implements Expected SARSA for quantile-valued Q distributions.
quantile_q_learning | Implements Q-learning for quantile-valued Q distributions.
qv_learning | Calculates the QV-learning temporal difference error.
qv_max | Calculates the QVMAX temporal difference error.
retrace | Calculates Retrace errors.
retrace_continuous | Calculates Retrace errors for continuous action spaces.
sarsa | Calculates the SARSA temporal difference error.
sarsa_lambda | Calculates the SARSA(lambda) temporal difference error.
td_lambda | Calculates the TD(lambda) temporal difference error.
td_learning | Calculates the TD-learning temporal difference error.
transformed_general_off_policy_returns_from_action_values | Calculates targets for various off-policy correction algorithms.
transformed_lambda_returns | Estimates a multistep truncated lambda return from a trajectory.
transformed_n_step_q_learning | Calculates transformed n-step TD errors.
transformed_n_step_returns | Computes strided n-step bootstrapped return targets over a sequence.
transformed_q_lambda | Calculates Peng’s or Watkins’ Q(lambda) temporal difference error.
transformed_retrace | Calculates transformed Retrace errors.
vtrace | Calculates V-Trace errors from importance weights.
vtrace_td_error_and_advantage | Calculates V-Trace errors and PG advantage from importance weights.
Categorical Double Q Learning
rlax.categorical_double_q_learning(q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t, q_logits_t, q_t_selector, stop_target_gradients=True)
Implements double Q-learning for categorical Q distributions.
See “A Distributional Perspective on Reinforcement Learning” by Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf) and “Double Q-learning” by van Hasselt (https://papers.nips.cc/paper/3964-double-q-learning.pdf).
- Parameters
q_atoms_tm1 (Array) – atoms of Q distribution at time t-1.
q_logits_tm1 (Array) – logits of Q distribution at time t-1.
a_tm1 (Numeric) – action index at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
q_atoms_t (Array) – atoms of Q distribution at time t.
q_logits_t (Array) – logits of Q distribution at time t.
q_t_selector (Array) – selector Q-values at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
Categorical double Q-learning loss (i.e. temporal difference error).
Categorical L2 Project
rlax.categorical_l2_project(z_p, probs, z_q)
Projects a categorical distribution (z_p, p) onto a different support z_q.
The projection step minimizes an L2-metric over the cumulative distribution functions (CDFs) of the source and target distributions.
Let kq be len(z_q) and kp be len(z_p). This projection works for any support z_q, in particular kq need not be equal to kp.
See “A Distributional Perspective on RL” by Bellemare et al. (https://arxiv.org/abs/1707.06887).
- Parameters
z_p (Array) – support of distribution p.
probs (Array) – probability values.
z_q (Array) – support to project distribution (z_p, probs) onto.
- Return type
Array
- Returns
Projection of (z_p, p) onto support z_q under Cramer distance.
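The projection can be illustrated with a plain-Python sketch. It is not the rlax implementation: the name `l2_project_sketch` is hypothetical, and the code assumes z_q is sorted in strictly increasing order. Each source atom is clipped to the range of z_q, and its probability mass is split linearly between the two neighbouring target atoms.

```python
from bisect import bisect_right


def l2_project_sketch(z_p, probs, z_q):
    """Projects (z_p, probs) onto the sorted support z_q (illustrative only)."""
    out = [0.0] * len(z_q)
    for z, p in zip(z_p, probs):
        # Atoms outside the target support are moved to its nearest edge.
        z = min(max(z, z_q[0]), z_q[-1])
        hi = bisect_right(z_q, z)  # first index whose atom is strictly above z
        if hi == len(z_q):  # z coincides with the last target atom
            out[-1] += p
            continue
        lo = hi - 1
        # Fraction of the mass assigned to the upper neighbour.
        w_hi = (z - z_q[lo]) / (z_q[hi] - z_q[lo])
        out[lo] += p * (1.0 - w_hi)
        out[hi] += p * w_hi
    return out
```

Because kq need not equal kp, a single source atom can be projected onto any number of target atoms; the output always has length len(z_q) and keeps the total probability mass.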
Categorical Q Learning
rlax.categorical_q_learning(q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t, q_logits_t, stop_target_gradients=True)
Implements Q-learning for categorical Q distributions.
See “A Distributional Perspective on Reinforcement Learning” by Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).
- Parameters
q_atoms_tm1 (Array) – atoms of Q distribution at time t-1.
q_logits_tm1 (Array) – logits of Q distribution at time t-1.
a_tm1 (Numeric) – action index at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
q_atoms_t (Array) – atoms of Q distribution at time t.
q_logits_t (Array) – logits of Q distribution at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
Categorical Q-learning loss (i.e. temporal difference error).
Categorical TD Learning
rlax.categorical_td_learning(v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t, stop_target_gradients=True)
Implements TD-learning for categorical value distributions.
See “A Distributional Perspective on Reinforcement Learning” by Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).
- Parameters
v_atoms_tm1 (Array) – atoms of V distribution at time t-1.
v_logits_tm1 (Array) – logits of V distribution at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
v_atoms_t (Array) – atoms of V distribution at time t.
v_logits_t (Array) – logits of V distribution at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
Categorical TD-learning loss (i.e. temporal difference error).
Discounted Returns
rlax.discounted_returns(r_t, discount_t, v_t, stop_target_gradients=False)
Calculates a discounted return from a trajectory.
The returns are computed recursively, from G_{T-1} to G_0, according to:
Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁.
See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/sutton/book/ebook/node61.html).
- Parameters
r_t (Array) – reward sequence at time t.
discount_t (Array) – discount sequence at time t.
v_t (Array) – value sequence or scalar at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Discounted returns.
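The backward recursion can be written in a few lines of plain Python. This is a sketch rather than the library code: the function name is hypothetical, the bootstrap value is assumed to be a single scalar applied after the final step, and gradient stopping is omitted because no autodiff is involved.

```python
def discounted_returns_sketch(r_t, discount_t, bootstrap_value=0.0):
    """Computes G_t = r_{t+1} + discount_{t+1} * G_{t+1}, from the last step back."""
    g = bootstrap_value
    returns = []
    for r, d in zip(reversed(r_t), reversed(discount_t)):
        g = r + d * g
        returns.append(g)
    return returns[::-1]  # restore chronological order
```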
Double Q Learning
rlax.double_q_learning(q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector, stop_target_gradients=True)
Calculates the double Q-learning temporal difference error.
See “Double Q-learning” by van Hasselt. (https://papers.nips.cc/paper/3964-double-q-learning.pdf).
- Parameters
q_tm1 (Array) – Q-values at time t-1.
a_tm1 (Numeric) – action index at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
q_t_value (Array) – Q-values at time t.
q_t_selector (Array) – selector Q-values at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
Double Q-learning temporal difference error.
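As a rough illustration of how the two sets of Q-values interact, the single-transition sketch below picks the greedy action with q_t_selector and evaluates it with q_t_value. The function name is hypothetical, and plain Python lists stand in for arrays.

```python
def double_q_error_sketch(q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector):
    """Double Q-learning TD error for one transition (illustrative only)."""
    # Action selection uses the selector values ...
    a_star = max(range(len(q_t_selector)), key=q_t_selector.__getitem__)
    # ... while the bootstrap value comes from the separate value estimates.
    target = r_t + discount_t * q_t_value[a_star]
    return target - q_tm1[a_tm1]
```

Decoupling selection from evaluation is what distinguishes this from the plain Q-learning error, which would bootstrap from max(q_t_value).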
Expected SARSA
rlax.expected_sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t, stop_target_gradients=True)
Calculates the expected SARSA (SARSE) temporal difference error.
See “A Theoretical and Empirical Analysis of Expected Sarsa” by Seijen, van Hasselt, Whiteson et al. (http://www.cs.ox.ac.uk/people/shimon.whiteson/pubs/vanseijenadprl09.pdf).
- Parameters
q_tm1 (Array) – Q-values at time t-1.
a_tm1 (Numeric) – action index at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
q_t (Array) – Q-values at time t.
probs_a_t (Array) – action probabilities at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
Expected SARSA temporal difference error.
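A minimal sketch of the same quantity for a single transition, assuming list inputs and a hypothetical function name: the bootstrap term is the expectation of q_t under the action probabilities probs_a_t.

```python
def expected_sarsa_error_sketch(q_tm1, a_tm1, r_t, discount_t, q_t, probs_a_t):
    """Expected SARSA TD error for one transition (illustrative only)."""
    expected_q_t = sum(p * q for p, q in zip(probs_a_t, q_t))
    return r_t + discount_t * expected_q_t - q_tm1[a_tm1]
```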
General Off Policy Returns From Action Values
rlax.general_off_policy_returns_from_action_values(q_t, a_t, r_t, discount_t, c_t, pi_t, stop_target_gradients=False)
Calculates targets for various off-policy correction algorithms.
Given a window of experience of length K, generated by a behaviour policy μ, for each time-step t we can estimate the return G_t from that step onwards, under some target policy π, using the rewards in the trajectory, the actions selected by μ and the action-values under π, according to equation:
Gₜ = rₜ₊₁ + γₜ₊₁ * (E[q(aₜ₊₁)] - cₜ₊₁ * q(aₜ₊₁) + cₜ₊₁ * Gₜ₊₁),
where, depending on the choice of c_t, the algorithm implements:
Importance Sampling: c_t = π(x_t, a_t) / μ(x_t, a_t),
Harutyunyan et al.’s Q(lambda): c_t = λ,
Precup et al.’s Tree-Backup: c_t = π(x_t, a_t),
Munos et al.’s Retrace: c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).
See “Safe and Efficient Off-Policy Reinforcement Learning” by Munos et al. (https://arxiv.org/abs/1606.02647).
- Parameters
q_t (Array) – Q-values at times [1, …, K - 1].
a_t (Array) – action index at times [1, …, K - 1].
r_t (Array) – reward at times [1, …, K - 1].
discount_t (Array) – discount at times [1, …, K - 1].
c_t (Array) – importance weights at times [1, …, K - 1].
pi_t (Array) – target policy probs at times [1, …, K - 1].
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Off-policy estimates of the generalized returns from states visited at times [0, …, K - 1].
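The four choices of c_t listed for this function can be summarised in a small helper. This is only a sketch: `trace_coefficient` and its string keys are hypothetical names, and pi_a and mu_a stand for π(x_t, a_t) and μ(x_t, a_t) at a single step.

```python
def trace_coefficient(kind, pi_a, mu_a, lambda_=1.0):
    """Returns c_t for one step under the named off-policy correction."""
    if kind == "importance_sampling":
        return pi_a / mu_a
    if kind == "q_lambda":
        return lambda_
    if kind == "tree_backup":
        return pi_a
    if kind == "retrace":
        return lambda_ * min(1.0, pi_a / mu_a)
    raise ValueError(f"unknown correction: {kind}")
```

A sequence of such coefficients, one per step, is what would be passed as c_t.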
General Off Policy Returns From Q and V
rlax.general_off_policy_returns_from_q_and_v(q_t, v_t, r_t, discount_t, c_t, stop_target_gradients=False)
Calculates targets for various off-policy evaluation algorithms.
Given a window of experience of length K+1, generated by a behaviour policy μ, for each time-step t we can estimate the return G_t from that step onwards, under some target policy π, using the rewards in the trajectory, the values under π of states and actions selected by μ, according to equation:
Gₜ = rₜ₊₁ + γₜ₊₁ * (vₜ₊₁ - cₜ₊₁ * q(aₜ₊₁) + cₜ₊₁* Gₜ₊₁),
where, depending on the choice of c_t, the algorithm implements:
Importance Sampling: c_t = π(x_t, a_t) / μ(x_t, a_t),
Harutyunyan et al.’s Q(lambda): c_t = λ,
Precup et al.’s Tree-Backup: c_t = π(x_t, a_t),
Munos et al.’s Retrace: c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).
See “Safe and Efficient Off-Policy Reinforcement Learning” by Munos et al. (https://arxiv.org/abs/1606.02647).
- Parameters
q_t (Array) – Q-values under π of actions executed by μ at times [1, …, K - 1].
v_t (Array) – Values under π at times [1, …, K].
r_t (Array) – rewards at times [1, …, K].
discount_t (Array) – discounts at times [1, …, K].
c_t (Array) – weights at times [1, …, K - 1].
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Off-policy estimates of the generalized returns from states visited at times [0, …, K - 1].
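The recursion for these targets can be sketched in plain Python as follows. The function name is hypothetical, and the list lengths follow the parameter description: q_t and c_t have K - 1 entries, while v_t, r_t and discount_t have K. The final return bootstraps from the last value only.

```python
def off_policy_returns_sketch(q_t, v_t, r_t, discount_t, c_t):
    """Backward recursion G = r + discount * (v - c * q + c * G_next)."""
    g = r_t[-1] + discount_t[-1] * v_t[-1]  # last step has no correction term
    returns = [g]
    for i in range(len(r_t) - 2, -1, -1):
        g = r_t[i] + discount_t[i] * (v_t[i] - c_t[i] * q_t[i] + c_t[i] * g)
        returns.append(g)
    return returns[::-1]
```

With c_t set to zero everywhere the recursion collapses to one-step targets r + γ v, which gives a quick sanity check.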
Lambda Returns
rlax.lambda_returns(r_t, discount_t, v_t, lambda_=1.0, stop_target_gradients=False)
Estimates a multistep truncated lambda return from a trajectory.
Given a trajectory of length T+1, generated under some policy π, for each time-step t we can estimate a target return G_t, by combining rewards, discounts, and state values, according to a mixing parameter lambda.
The parameter lambda_ mixes the different multi-step bootstrapped returns, corresponding to accumulating k rewards and then bootstrapping using v_t.
rₜ₊₁ + γₜ₊₁ vₜ₊₁
rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ vₜ₊₂
rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ rₜ₊₃ + γₜ₊₁ γₜ₊₂ γₜ₊₃ vₜ₊₃
The returns are computed recursively, from G_{T-1} to G_0, according to:
Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].
In the on-policy case, we estimate a return target G_t for the same policy π that was used to generate the trajectory. In this setting the parameter lambda_ is typically a fixed scalar factor. Depending on how values v_t are computed, this function can be used to construct targets for different multistep reinforcement learning updates:
TD(λ): v_t contains the state value estimates for each state under π.
Q(λ): v_t = max(q_t, axis=-1), where q_t estimates the action values.
Sarsa(λ): v_t = q_t[…, a_t], where q_t estimates the action values.
In the off-policy case, the mixing factor is a function of state, and different definitions of lambda implement different off-policy corrections:
Per-decision importance sampling: λₜ = λ ρₜ = λ [π(aₜ|sₜ) / μ(aₜ|sₜ)]
V-trace, as instantiated in IMPALA: λₜ = min(1, ρₜ)
Note that the second option is equivalent to applying per-decision importance sampling, but using an adaptive λ(ρₜ) = min(1/ρₜ, 1), such that the effective bootstrap parameter at time t becomes λₜ = λ(ρₜ) * ρₜ = min(1, ρₜ). This is the interpretation used in the ABQ(ζ) algorithm (Mahmood 2017).
Of course this can be augmented to include an additional factor λ. For instance we could use V-trace with a fixed additional parameter λ = 0.9, by setting λₜ = 0.9 * min(1, ρₜ) or, alternatively (but not equivalently), λₜ = min(0.9, ρₜ).
Estimated returns are then often used to define a TD error, e.g.: ρₜ(Gₜ - vₜ).
See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/sutton/book/ebook/node74.html).
- Parameters
r_t (Array) – sequence of rewards rₜ for timesteps t in [1, T].
discount_t (Array) – sequence of discounts γₜ for timesteps t in [1, T].
v_t (Array) – sequence of state values estimates under π for timesteps t in [1, T].
lambda_ – mixing parameter; a scalar or a vector for timesteps t in [1, T].
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Multistep lambda returns.
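The recursion is easy to check by hand with a plain-Python sketch. The function name is hypothetical; lambda_ may be a scalar or a per-step sequence, and the recursion is seeded with the last value so that the final return is a one-step target.

```python
def lambda_returns_sketch(r_t, discount_t, v_t, lambda_=1.0):
    """G_t = r + discount * ((1 - lambda) * v + lambda * G_next), computed backwards."""
    if isinstance(lambda_, (int, float)):
        lambdas = [lambda_] * len(r_t)
    else:
        lambdas = list(lambda_)
    g = v_t[-1]  # bootstrap value for the final step
    returns = []
    for i in range(len(r_t) - 1, -1, -1):
        g = r_t[i] + discount_t[i] * ((1.0 - lambdas[i]) * v_t[i] + lambdas[i] * g)
        returns.append(g)
    return returns[::-1]
```

Setting lambda_ to 0 recovers one-step targets, and setting it to 1 recovers the discounted return bootstrapped from the last value.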
Leaky VTrace
rlax.leaky_vtrace(v_tm1, v_t, r_t, discount_t, rho_tm1, alpha_=1.0, lambda_=1.0, clip_rho_threshold=1.0, stop_target_gradients=True)
Calculates Leaky V-Trace errors from importance weights.
Leaky-Vtrace is a combination of Importance sampling and V-trace, where the degree of mixing is controlled by a scalar alpha (that may be meta-learnt).
See “Self-Tuning Deep Reinforcement Learning” by Zahavy et al. (https://arxiv.org/abs/2002.12928)
- Parameters
v_tm1 (Array) – values at time t-1.
v_t (Array) – values at time t.
r_t (Array) – reward at time t.
discount_t (Array) – discount at time t.
rho_tm1 (Array) – importance weights at time t-1.
alpha_ – mixing parameter for Importance Sampling and V-trace.
lambda_ – scalar mixing parameter lambda.
clip_rho_threshold (float) – clip threshold for importance weights.
stop_target_gradients (bool) – whether or not to apply stop gradient to targets.
- Returns
Leaky V-Trace error.
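The role of the mixing parameter can be sketched for a single importance weight. This assumes that the mixing is a convex combination of the clipped weight (as in V-trace) and the unclipped weight (as in importance sampling), which is how the description above reads; the function name is hypothetical.

```python
def leaky_rho_sketch(rho, alpha=1.0, clip_rho_threshold=1.0):
    """Blends a clipped importance weight with its unclipped value."""
    clipped = min(clip_rho_threshold, rho)
    return alpha * clipped + (1.0 - alpha) * rho
```

Under this assumption alpha = 1 gives the clipped V-trace weight and alpha = 0 gives the plain importance sampling weight.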
N Step Bootstrapped Returns
rlax.n_step_bootstrapped_returns(r_t, discount_t, v_t, n, lambda_t=1.0, stop_target_gradients=False)
Computes strided n-step bootstrapped return targets over a sequence.
The returns are computed according to the below equation iterated n times:
Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].
When lambda_t == 1. (default), this reduces to
Gₜ = rₜ₊₁ + γₜ₊₁ * (rₜ₊₂ + γₜ₊₂ * (… * (rₜ₊ₙ + γₜ₊ₙ * vₜ₊ₙ ))).
- Parameters
r_t (Array) – rewards at times [1, …, T].
discount_t (Array) – discounts at times [1, …, T].
v_t (Array) – state or state-action values to bootstrap from at times [1, …, T].
n (int) – number of steps over which to accumulate reward before bootstrapping.
lambda_t (Numeric) – lambdas at times [1, …, T]. Shape is [], or [T-1].
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Estimated bootstrapped returns at times [0, …, T-1].
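For the default lambda_t == 1 case, the n-step targets can be sketched as below. The function name is hypothetical, and the handling of the last n - 1 steps is an assumption: the sketch accumulates whatever rewards remain and bootstraps from the final value.

```python
def n_step_returns_sketch(r_t, discount_t, v_t, n):
    """n-step bootstrapped returns with lambda fixed to 1 (illustrative only)."""
    num_steps = len(r_t)
    returns = []
    for t in range(num_steps):
        end = min(t + n, num_steps)  # truncate the window at the sequence end
        g = v_t[end - 1]  # value to bootstrap from after the window
        for k in range(end - 1, t - 1, -1):
            g = r_t[k] + discount_t[k] * g
        returns.append(g)
    return returns
```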
Leaky VTrace TD Error and Advantage
rlax.leaky_vtrace_td_error_and_advantage(v_tm1, v_t, r_t, discount_t, rho_tm1, alpha=1.0, lambda_=1.0, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, stop_target_gradients=True)
Calculates Leaky V-Trace errors and PG advantage from importance weights.
This function computes the Leaky V-Trace TD-errors and policy gradient advantage terms as used by the IMPALA distributed actor-critic agent.
Leaky-Vtrace is a combination of Importance sampling and V-trace, where the degree of mixing is controlled by a scalar alpha (that may be meta-learnt).
See “Self-Tuning Deep Reinforcement Learning” by Zahavy et al. (https://arxiv.org/abs/2002.12928) and “IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor Learner Architectures” by Espeholt et al. (https://arxiv.org/abs/1802.01561)
- Parameters
v_tm1 (chex.Array) – values at time t-1.
v_t (chex.Array) – values at time t.
r_t (chex.Array) – reward at time t.
discount_t (chex.Array) – discount at time t.
rho_tm1 (chex.Array) – importance weights at time t-1.
alpha (float) – mixing the clipped importance sampling weights with unclipped ones.
lambda_ – scalar mixing parameter lambda.
clip_rho_threshold (float) – clip threshold for importance ratios.
clip_pg_rho_threshold (float) – clip threshold for policy gradient importance ratios.
stop_target_gradients (bool) – whether or not to apply stop gradient to targets.
- Return type
VTraceOutput
- Returns
a tuple of V-Trace error, policy gradient advantage, and estimated Q-values.
Persistent Q Learning
rlax.persistent_q_learning(q_tm1, a_tm1, r_t, discount_t, q_t, action_gap_scale, stop_target_gradients=True)
Calculates the persistent Q-learning temporal difference error.
See “Increasing the Action Gap: New Operators for Reinforcement Learning” by Bellemare, Ostrovski, Guez et al. (https://arxiv.org/abs/1512.04860).
- Parameters
q_tm1 (Array) – Q-values at time t-1.
a_tm1 (Numeric) – action index at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
q_t (Array) – Q-values at time t.
action_gap_scale (float) – coefficient in [0, 1] for scaling the action gap term.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
Persistent Q-learning temporal difference error.
Q-Lambda
rlax.q_lambda(q_tm1, a_tm1, r_t, discount_t, q_t, lambda_, stop_target_gradients=True)
Calculates Peng’s or Watkins’ Q(lambda) temporal difference error.
See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/book/ebook/node78.html).
- Parameters
q_tm1 (Array) – sequence of Q-values at time t-1.
a_tm1 (Array) – sequence of action indices at time t-1.
r_t (Array) – sequence of rewards at time t.
discount_t (Array) – sequence of discounts at time t.
q_t (Array) – sequence of Q-values at time t.
lambda_ – mixing parameter lambda, either a scalar (e.g. Peng’s Q(lambda)) or a sequence (e.g. Watkins’ Q(lambda)).
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Q(lambda) temporal difference error.
Q Learning
rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t, stop_target_gradients=True)
Calculates the Q-learning temporal difference error.
See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/book/ebook/node65.html).
- Parameters
q_tm1 (Array) – Q-values at time t-1.
a_tm1 (Numeric) – action index at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
q_t (Array) – Q-values at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
Q-learning temporal difference error.
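For a single transition the error reduces to one line. The sketch below uses a hypothetical function name and plain lists in place of arrays, and ignores gradient stopping.

```python
def q_learning_error_sketch(q_tm1, a_tm1, r_t, discount_t, q_t):
    """Q-learning TD error: greedy bootstrap target minus the current estimate."""
    return r_t + discount_t * max(q_t) - q_tm1[a_tm1]
```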
Quantile Expected Sarsa
rlax.quantile_expected_sarsa(dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t, probs_a_t, huber_param=0.0, stop_target_gradients=True)
Implements Expected SARSA for quantile-valued Q distributions.
- Parameters
dist_q_tm1 (Array) – Q distribution at time t-1.
tau_q_tm1 (Array) – Q distribution probability thresholds.
a_tm1 (Numeric) – action index at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
dist_q_t (Array) – target Q distribution at time t.
probs_a_t (Array) – action probabilities at time t.
huber_param (float) – Huber loss parameter, defaults to 0 (no Huber loss).
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
Quantile regression Expected SARSA learning loss.
Quantile Q Learning
rlax.quantile_q_learning(dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector, dist_q_t, huber_param=0.0, stop_target_gradients=True)
Implements Q-learning for quantile-valued Q distributions.
See “Distributional Reinforcement Learning with Quantile Regression” by Dabney et al. (https://arxiv.org/abs/1710.10044).
- Parameters
dist_q_tm1 (Array) – Q distribution at time t-1.
tau_q_tm1 (Array) – Q distribution probability thresholds.
a_tm1 (Numeric) – action index at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
dist_q_t_selector (Array) – Q distribution at time t for selecting greedy action in target policy. This is separate from dist_q_t as in Double Q-Learning, but can be computed with the target network and a separate set of samples.
dist_q_t (Array) – target Q distribution at time t.
huber_param (float) – Huber loss parameter, defaults to 0 (no Huber loss).
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
Quantile regression Q learning loss.
QV Learning
rlax.qv_learning(q_tm1, a_tm1, r_t, discount_t, v_t, stop_target_gradients=True)
Calculates the QV-learning temporal difference error.
See “Two Novel On-policy Reinforcement Learning Algorithms based on TD(lambda)-methods” by Wiering and van Hasselt (https://ieeexplore.ieee.org/abstract/document/4220845).
- Parameters
q_tm1 (Array) – Q-values at time t-1.
a_tm1 (Numeric) – action index at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
v_t (Numeric) – state values at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
QV-learning temporal difference error.
QV Max
rlax.qv_max(v_tm1, r_t, discount_t, q_t, stop_target_gradients=True)
Calculates the QVMAX temporal difference error.
See “The QV Family Compared to Other Reinforcement Learning Algorithms” by Wiering and van Hasselt (2009). (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.713.1931)
- Parameters
v_tm1 (Numeric) – state values at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
q_t (Array) – Q-values at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
QVMAX temporal difference error.
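The QV-learning and QVMAX errors mirror each other: one updates action values towards a state-value target, the other updates a state value towards a greedy action-value target. A single-transition sketch, with hypothetical function names and lists in place of arrays:

```python
def qv_learning_error_sketch(q_tm1, a_tm1, r_t, discount_t, v_t):
    """QV-learning: bootstrap the action value from the next state value."""
    return r_t + discount_t * v_t - q_tm1[a_tm1]


def qv_max_error_sketch(v_tm1, r_t, discount_t, q_t):
    """QVMAX: bootstrap the state value from the greedy next action value."""
    return r_t + discount_t * max(q_t) - v_tm1
```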
Retrace
rlax.retrace(q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t, lambda_, eps=1e-08, stop_target_gradients=True)
Calculates Retrace errors.
See “Safe and Efficient Off-Policy Reinforcement Learning” by Munos et al. (https://arxiv.org/abs/1606.02647).
- Parameters
q_tm1 (Array) – Q-values at time t-1.
q_t (Array) – Q-values at time t.
a_tm1 (Array) – action index at time t-1.
a_t (Array) – action index at time t.
r_t (Array) – reward at time t.
discount_t (Array) – discount at time t.
pi_t (Array) – target policy probs at time t.
mu_t (Array) – behavior policy probs at time t.
lambda_ – scalar mixing parameter lambda.
eps (float) – small value to add to mu_t for numerical stability.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Retrace error.
Retrace Continuous
rlax.retrace_continuous(q_tm1, q_t, v_t, r_t, discount_t, log_rhos, lambda_, stop_target_gradients=True)
Calculates Retrace errors for continuous action spaces.
See “Safe and Efficient Off-Policy Reinforcement Learning” by Munos et al. (https://arxiv.org/abs/1606.02647).
- Parameters
q_tm1 (Array) – Q-values at times [0, …, K - 1].
q_t (Array) – Q-values evaluated at actions collected using behavior policy at times [1, …, K - 1].
v_t (Array) – Value estimates of the target policy at times [1, …, K].
r_t (Array) – reward at times [1, …, K].
discount_t (Array) – discount at times [1, …, K].
log_rhos (Array) – Log importance weight pi_target/pi_behavior evaluated at actions collected using behavior policy [1, …, K - 1].
lambda_ – scalar or a vector of mixing parameter lambda.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Retrace error.
SARSA
rlax.sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, a_t, stop_target_gradients=True)
Calculates the SARSA temporal difference error.
See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/book/ebook/node64.html).
- Parameters
q_tm1 (Array) – Q-values at time t-1.
a_tm1 (Numeric) – action index at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
q_t (Array) – Q-values at time t.
a_t (Numeric) – action index at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
SARSA temporal difference error.
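A single-transition sketch, assuming list inputs and a hypothetical function name. Unlike Q-learning, the bootstrap term uses the action a_t that was actually taken at time t.

```python
def sarsa_error_sketch(q_tm1, a_tm1, r_t, discount_t, q_t, a_t):
    """SARSA TD error: on-policy bootstrap from the action taken at time t."""
    return r_t + discount_t * q_t[a_t] - q_tm1[a_tm1]
```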
SARSA Lambda
rlax.sarsa_lambda(q_tm1, a_tm1, r_t, discount_t, q_t, a_t, lambda_, stop_target_gradients=True)
Calculates the SARSA(lambda) temporal difference error.
See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/book/ebook/node77.html).
- Parameters
q_tm1 (Array) – sequence of Q-values at time t-1.
a_tm1 (Array) – sequence of action indices at time t-1.
r_t (Array) – sequence of rewards at time t.
discount_t (Array) – sequence of discounts at time t.
q_t (Array) – sequence of Q-values at time t.
a_t (Array) – sequence of action indices at time t.
lambda_ – mixing parameter lambda, either a scalar or a sequence.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
SARSA(lambda) temporal difference error.
TD Lambda
rlax.td_lambda(v_tm1, r_t, discount_t, v_t, lambda_, stop_target_gradients=True)
Calculates the TD(lambda) temporal difference error.
See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/book/ebook/node74.html).
- Parameters
v_tm1 (Array) – sequence of state values at time t-1.
r_t (Array) – sequence of rewards at time t.
discount_t (Array) – sequence of discounts at time t.
v_t (Array) – sequence of state values at time t.
lambda_ – mixing parameter lambda, either a scalar or a sequence.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
TD(lambda) temporal difference error.
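One way to read this error is as a lambda return minus the previous value estimate at each step. The sketch below makes that assumption explicit; the function name is hypothetical and only a scalar lambda_ is handled.

```python
def td_lambda_errors_sketch(v_tm1, r_t, discount_t, v_t, lambda_):
    """Per-step TD(lambda) errors: lambda return minus v_tm1 (illustrative only)."""
    g = v_t[-1]  # the final lambda return bootstraps from the last value
    errors = []
    for i in range(len(r_t) - 1, -1, -1):
        g = r_t[i] + discount_t[i] * ((1.0 - lambda_) * v_t[i] + lambda_ * g)
        errors.append(g - v_tm1[i])
    return errors[::-1]
```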
TD Learning
rlax.td_learning(v_tm1, r_t, discount_t, v_t, stop_target_gradients=True)
Calculates the TD-learning temporal difference error.
See “Learning to Predict by the Methods of Temporal Differences” by Sutton. (https://link.springer.com/article/10.1023/A:1022633531479).
- Parameters
v_tm1 (Numeric) – state values at time t-1.
r_t (Numeric) – reward at time t.
discount_t (Numeric) – discount at time t.
v_t (Numeric) – state values at time t.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Numeric
- Returns
TD-learning temporal difference error.
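For scalar inputs the one-step error is a direct transcription of the TD target. The function name below is hypothetical, and gradient stopping is not modelled.

```python
def td_error_sketch(v_tm1, r_t, discount_t, v_t):
    """One-step TD error: bootstrapped target minus the previous value."""
    return r_t + discount_t * v_t - v_tm1
```

A discount of 0 at a terminal transition removes the bootstrap term, so the error becomes r_t - v_tm1.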
Transformed General Off Policy Returns from Action Values
rlax.transformed_general_off_policy_returns_from_action_values(q_t, a_t, r_t, discount_t, c_t, pi_t, stop_target_gradients=False)
Calculates targets for various off-policy correction algorithms.
Given a window of experience of length K, generated by a behaviour policy μ, for each time-step t we can estimate the return G_t from that step onwards, under some target policy π, using the rewards in the trajectory, the actions selected by μ and the action-values under π, according to equation:
Gₜ = rₜ₊₁ + γₜ₊₁ * (E[q(aₜ₊₁)] - cₜ₊₁ * q(aₜ₊₁) + cₜ₊₁ * Gₜ₊₁),
where, depending on the choice of c_t, the algorithm implements:
Importance Sampling: c_t = π(x_t, a_t) / μ(x_t, a_t),
Harutyunyan et al.’s Q(lambda): c_t = λ,
Precup et al.’s Tree-Backup: c_t = π(x_t, a_t),
Munos et al.’s Retrace: c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).
See “Safe and Efficient Off-Policy Reinforcement Learning” by Munos et al. (https://arxiv.org/abs/1606.02647).
- Parameters
q_t (Array) – Q-values at times [1, …, K - 1].
a_t (Array) – action index at times [1, …, K - 1].
r_t (Array) – reward at times [1, …, K - 1].
discount_t (Array) – discount at times [1, …, K - 1].
c_t (Array) – importance weights at times [1, …, K - 1].
pi_t (Array) – target policy probs at times [1, …, K - 1].
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Off-policy estimates of the generalized returns from states visited at times [0, …, K - 1].
Transformed Lambda Returns
rlax.transformed_lambda_returns(r_t, discount_t, v_t, lambda_=1.0, stop_target_gradients=False)
Estimates a multistep truncated lambda return from a trajectory.
Given a trajectory of length T+1, generated under some policy π, for each time-step t we can estimate a target return G_t, by combining rewards, discounts, and state values, according to a mixing parameter lambda.
The parameter lambda_ mixes the different multi-step bootstrapped returns, corresponding to accumulating k rewards and then bootstrapping using v_t.
rₜ₊₁ + γₜ₊₁ vₜ₊₁
rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ vₜ₊₂
rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ rₜ₊₃ + γₜ₊₁ γₜ₊₂ γₜ₊₃ vₜ₊₃
The returns are computed recursively, from G_{T-1} to G_0, according to:
Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].
In the on-policy case, we estimate a return target G_t for the same policy π that was used to generate the trajectory. In this setting the parameter lambda_ is typically a fixed scalar factor. Depending on how values v_t are computed, this function can be used to construct targets for different multistep reinforcement learning updates:
TD(λ): v_t contains the state value estimates for each state under π.
Q(λ): v_t = max(q_t, axis=-1), where q_t estimates the action values.
Sarsa(λ): v_t = q_t[…, a_t], where q_t estimates the action values.
In the off-policy case, the mixing factor is a function of state, and different definitions of lambda implement different off-policy corrections:
Per-decision importance sampling: λₜ = λ ρₜ = λ [π(aₜ|sₜ) / μ(aₜ|sₜ)]
V-trace, as instantiated in IMPALA: λₜ = min(1, ρₜ)
Note that the second option is equivalent to applying per-decision importance sampling, but using an adaptive λ(ρₜ) = min(1/ρₜ, 1), such that the effective bootstrap parameter at time t becomes λₜ = λ(ρₜ) * ρₜ = min(1, ρₜ). This is the interpretation used in the ABQ(ζ) algorithm (Mahmood 2017).
Of course this can be augmented to include an additional factor λ. For instance we could use V-trace with a fixed additional parameter λ = 0.9, by setting λₜ = 0.9 * min(1, ρₜ) or, alternatively (but not equivalently), λₜ = min(0.9, ρₜ).
Estimated returns are then often used to define a TD error, e.g.: ρₜ(Gₜ - vₜ).
See “Reinforcement Learning: An Introduction” by Sutton and Barto. (http://incompleteideas.net/sutton/book/ebook/node74.html).
- Parameters
r_t (Array) – sequence of rewards rₜ for timesteps t in [1, T].
discount_t (Array) – sequence of discounts γₜ for timesteps t in [1, T].
v_t (Array) – sequence of state values estimates under π for timesteps t in [1, T].
lambda_ – mixing parameter; a scalar or a vector for timesteps t in [1, T].
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Multistep lambda returns.
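The backward recursion described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library implementation: `lambda_returns_sketch` is a hypothetical helper that operates on Python lists, where index i holds the quantity for timestep i+1, and the final step is assumed to bootstrap from the last value estimate.

```python
def lambda_returns_sketch(r_t, discount_t, v_t, lambda_=1.0):
    """Hypothetical sketch of G_t = r + g * [(1 - l) * v + l * G_next]."""
    num_steps = len(r_t)
    # Accept either a scalar lambda or one lambda per timestep.
    if isinstance(lambda_, (int, float)):
        lambdas = [lambda_] * num_steps
    else:
        lambdas = list(lambda_)
    returns = [0.0] * num_steps
    g_next = v_t[-1]  # Assumption: bootstrap from the last value estimate.
    # Work backwards from the end of the trajectory to its start.
    for i in reversed(range(num_steps)):
        mixed = (1.0 - lambdas[i]) * v_t[i] + lambdas[i] * g_next
        g_next = r_t[i] + discount_t[i] * mixed
        returns[i] = g_next
    return returns
```

With lambda_ = 0 every target is a one-step bootstrapped return, and with lambda_ = 1 the targets are discounted sums of rewards that bootstrap only from the final value.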
Transformed N Step Q Learning
rlax.transformed_n_step_q_learning(q_tm1, a_tm1, target_q_t, a_t, r_t, discount_t, n, stop_target_gradients=True, tx_pair=TxPair(apply=<function identity>, apply_inv=<function identity>))
Calculates transformed n-step TD errors.
See “Recurrent Experience Replay in Distributed Reinforcement Learning” by Kapturowski et al. (https://openreview.net/pdf?id=r1lyTjAqYX).
- Parameters
q_tm1 (Array) – Q-values at times [0, …, T - 1].
a_tm1 (Array) – action index at times [0, …, T - 1].
target_q_t (Array) – target Q-values at times [1, …, T].
a_t (Array) – action index at times [1, …, T] used to select target q-values to bootstrap from; argmax(target_q_t) for normal Q-learning, argmax(q_t) for double Q-learning.
r_t (Array) – reward at times [1, …, T].
discount_t (Array) – discount at times [1, …, T].
n (int) – number of steps over which to accumulate reward before bootstrapping.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
tx_pair (TxPair) – TxPair of value function transformation and its inverse.
- Return type
Array
- Returns
Transformed N-step TD error.
Transformed N Step Returns
rlax.transformed_n_step_returns(r_t, discount_t, v_t, n, lambda_t=1.0, stop_target_gradients=False)
Computes strided n-step bootstrapped return targets over a sequence.
The returns are computed according to the below equation iterated n times:
Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].
When lambda_t == 1. (default), this reduces to
Gₜ = rₜ₊₁ + γₜ₊₁ * (rₜ₊₂ + γₜ₊₂ * (… * (rₜ₊ₙ + γₜ₊ₙ * vₜ₊ₙ ))).
- Parameters
r_t (Array) – rewards at times [1, …, T].
discount_t (Array) – discounts at times [1, …, T].
v_t (Array) – state or state-action values to bootstrap from at times [1, …, T].
n (int) – number of steps over which to accumulate reward before bootstrapping.
lambda_t (Numeric) – lambdas at times [1, …, T]. Shape is [], or [T-1].
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
estimated bootstrapped returns at times [0, …, T-1]
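For the default case lambda_t == 1, the n-step target can be sketched in plain Python as below. `n_step_returns_sketch` is a hypothetical helper that does not call RLax; how windows running past the end of the sequence are handled is an assumption (here they are truncated and bootstrap from the last available value).

```python
def n_step_returns_sketch(r_t, discount_t, v_t, n):
    """Hypothetical sketch of n-step bootstrapped targets with lambda_t == 1."""
    num_steps = len(r_t)
    targets = []
    for start in range(num_steps):
        # Assumption: windows that run past the end are truncated.
        end = min(start + n, num_steps)
        # Bootstrap from the value that follows the last accumulated reward.
        g = v_t[end - 1]
        # Fold rewards and discounts in from the end of the window.
        for i in reversed(range(start, end)):
            g = r_t[i] + discount_t[i] * g
        targets.append(g)
    return targets
```

Setting n = 1 recovers one-step TD targets, while larger n accumulates more rewards before bootstrapping.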
Transformed Q Lambda
rlax.transformed_q_lambda(q_tm1, a_tm1, r_t, discount_t, q_t, lambda_, stop_target_gradients=True, tx_pair=TxPair(apply=<function identity>, apply_inv=<function identity>))
Calculates Peng’s or Watkins’ Q(lambda) temporal difference error.
See “General non-linear Bellman equations” by van Hasselt et al. (https://arxiv.org/abs/1907.03687).
- Parameters
q_tm1 (Array) – sequence of Q-values at time t-1.
a_tm1 (Array) – sequence of action indices at time t-1.
r_t (Array) – sequence of rewards at time t.
discount_t (Array) – sequence of discounts at time t.
q_t (Array) – sequence of Q-values at time t.
lambda_ – mixing parameter lambda, either a scalar (e.g. Peng’s Q(lambda)) or a sequence (e.g. Watkins’ Q(lambda)).
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
tx_pair (TxPair) – TxPair of value function transformation and its inverse.
- Return type
Array
- Returns
Q(lambda) temporal difference error.
Transformed Retrace
rlax.transformed_retrace(q_tm1, q_t, a_tm1, a_t, r_t, discount_t, pi_t, mu_t, lambda_, eps=1e-08, stop_target_gradients=True, tx_pair=TxPair(apply=<function identity>, apply_inv=<function identity>))
Calculates transformed Retrace errors.
See “Recurrent Experience Replay in Distributed Reinforcement Learning” by Kapturowski et al. (https://openreview.net/pdf?id=r1lyTjAqYX).
- Parameters
q_tm1 (Array) – Q-values at time t-1.
q_t (Array) – Q-values at time t.
a_tm1 (Array) – action index at time t-1.
a_t (Array) – action index at time t.
r_t (Array) – reward at time t.
discount_t (Array) – discount at time t.
pi_t (Array) – target policy probs at time t.
mu_t (Array) – behavior policy probs at time t.
lambda_ – scalar mixing parameter lambda.
eps (float) – small value to add to mu_t for numerical stability.
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
tx_pair (TxPair) – TxPair of value function transformation and its inverse.
- Return type
Array
- Returns
Transformed Retrace error.
Truncated Generalized Advantage Estimation
rlax.truncated_generalized_advantage_estimation(r_t, discount_t, lambda_, values, stop_target_gradients=False)
Computes truncated generalized advantage estimates for a sequence length k.
The advantages are computed in a backwards fashion according to the equation: Âₜ = δₜ + (γλ) * δₜ₊₁ + … + (γλ)ᵏ⁻ᵗ⁻¹ * δₖ₋₁, where δₜ = rₜ₊₁ + γₜ₊₁ * v(sₜ₊₁) - v(sₜ).
See Proximal Policy Optimization Algorithms, Schulman et al.: https://arxiv.org/abs/1707.06347
Note: This paper uses a different notation than the RLax standard convention that follows Sutton & Barto. We use rₜ₊₁ to denote the reward received after acting in state sₜ, while the PPO paper uses rₜ.
- Parameters
r_t (Array) – Sequence of rewards at times [1, k]
discount_t (Array) – Sequence of discounts at times [1, k]
lambda_ – Mixing parameter; a scalar or sequence of lambda_t at times [1, k]
values (Array) – Sequence of values under π at times [0, k]
stop_target_gradients (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
Multistep truncated generalized advantage estimation at times [0, k-1].
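The backward accumulation of advantages can be sketched in plain Python as follows. `truncated_gae_sketch` is a hypothetical helper that does not call RLax; it assumes a scalar lambda_ and uses the recursion Âₜ = δₜ + γₜ₊₁ λ Âₜ₊₁, which unrolls to the sum of discounted TD errors.

```python
def truncated_gae_sketch(r_t, discount_t, lambda_, values):
    """Hypothetical sketch: values has one more entry than r_t."""
    k = len(r_t)
    advantages = [0.0] * k
    acc = 0.0  # Advantage of the step after the current one.
    for t in reversed(range(k)):
        # One-step TD error for the transition starting at time t.
        delta = r_t[t] + discount_t[t] * values[t + 1] - values[t]
        acc = delta + discount_t[t] * lambda_ * acc
        advantages[t] = acc
    return advantages
```

With lambda_ = 0 each advantage reduces to the one-step TD error.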
VTrace
rlax.vtrace(v_tm1, v_t, r_t, discount_t, rho_tm1, lambda_=1.0, clip_rho_threshold=1.0, stop_target_gradients=True)
Calculates V-Trace errors from importance weights.
V-trace computes TD-errors from multistep trajectories by applying off-policy corrections based on clipped importance sampling ratios.
See “IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor Learner Architectures” by Espeholt et al. (https://arxiv.org/abs/1802.01561).
- Parameters
v_tm1 (Array) – values at time t-1.
v_t (Array) – values at time t.
r_t (Array) – reward at time t.
discount_t (Array) – discount at time t.
rho_tm1 (Array) – importance sampling ratios at time t-1.
lambda_ – scalar mixing parameter lambda.
clip_rho_threshold (float) – clip threshold for importance weights.
stop_target_gradients (bool) – whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
V-Trace error.
Policy Optimization
clipped_surrogate_pg_loss – Computes the clipped surrogate policy gradient loss.
dpg_loss – Calculates the deterministic policy gradient (DPG) loss.
entropy_loss – Calculates the entropy regularization loss.
mpo_loss – Implements the MPO loss with a KL bound.
mpo_compute_weights_and_temperature_loss – Computes the weights and temperature loss for MPO.
policy_gradient_loss – Calculates the policy gradient loss.
qpg_loss – Computes the QPG (Q-based Policy Gradient) loss.
rm_loss – Computes the RMPG (Regret Matching Policy Gradient) loss.
rpg_loss – Computes the RPG (Regret Policy Gradient) loss.
Clipped Surrogate PG Loss
rlax.clipped_surrogate_pg_loss(prob_ratios_t, adv_t, epsilon, use_stop_gradient=True)
Computes the clipped surrogate policy gradient loss.
L_clipₜ(θ) = - min(rₜ(θ)Âₜ, clip(rₜ(θ), 1-ε, 1+ε)Âₜ)
Where rₜ(θ) = π_θ(aₜ| sₜ) / π_θ_old(aₜ| sₜ) and Âₜ are the advantages.
See Proximal Policy Optimization Algorithms, Schulman et al.: https://arxiv.org/abs/1707.06347
- Parameters
prob_ratios_t (Array) – Ratio of action probabilities for actions a_t: rₜ(θ) = π_θ(aₜ| sₜ) / π_θ_old(aₜ| sₜ)
adv_t (Array) – the observed or estimated advantages from executing actions a_t.
epsilon (Scalar) – Scalar value corresponding to how much to clip the objective.
use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to advantages.
- Return type
Array
- Returns
Loss whose gradient corresponds to a clipped surrogate policy gradient update.
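The clipped objective can be evaluated per timestep with a few lines of plain Python. `clipped_surrogate_sketch` is a hypothetical helper that does not call RLax; it returns the per-timestep terms of L_clip and leaves out both the reduction over time and the stop-gradient handling.

```python
def clipped_surrogate_sketch(prob_ratios_t, adv_t, epsilon):
    """Hypothetical sketch of -min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    losses = []
    for ratio, adv in zip(prob_ratios_t, adv_t):
        # Restrict the probability ratio to [1 - epsilon, 1 + epsilon].
        clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
        # Take the more pessimistic of the clipped and unclipped objectives.
        losses.append(-min(ratio * adv, clipped_ratio * adv))
    return losses
```

For a positive advantage the clip caps the gain from increasing the ratio above 1 + ε; for a negative advantage it caps the gain from decreasing the ratio below 1 - ε.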
Compute Parametric KL Penalty and Dual Loss
DPG Loss
rlax.dpg_loss(a_t, dqda_t, dqda_clipping=None, use_stop_gradient=True)
Calculates the deterministic policy gradient (DPG) loss.
See “Deterministic Policy Gradient Algorithms” by Silver, Lever, Heess, Degris, Wierstra, Riedmiller (http://proceedings.mlr.press/v32/silver14.pdf).
- Parameters
a_t (Array) – continuous-valued action at time t.
dqda_t (Array) – gradient of Q(s,a) wrt. a, evaluated at time t.
dqda_clipping (Optional[Scalar]) – clips the gradient to have norm <= dqda_clipping.
use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to targets.
- Return type
Array
- Returns
DPG loss.
Entropy Loss
rlax.entropy_loss(logits_t, w_t)
Calculates the entropy regularization loss.
See “Function Optimization using Connectionist RL Algorithms” by Williams. (https://www.tandfonline.com/doi/abs/10.1080/09540099108946587)
- Parameters
logits_t (Array) – a sequence of unnormalized action preferences.
w_t (Array) – a per timestep weighting for the loss.
- Return type
Array
- Returns
Entropy loss.
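As a rough illustration of what an entropy regularization loss computes, the plain-Python sketch below takes the entropy of the softmax policy at each timestep and returns its weighted negative mean. `entropy_loss_sketch` is a hypothetical helper that does not call RLax; the sign convention and the averaging over timesteps are assumptions.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def entropy_loss_sketch(logits_t, w_t):
    """Hypothetical sketch: weighted negative policy entropy, averaged over time."""
    total = 0.0
    for logits, w in zip(logits_t, w_t):
        probs = softmax(logits)
        entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
        # Minimizing the negative entropy pushes the policy towards uniform.
        total += -w * entropy
    return total / len(logits_t)
```

A uniform policy over two actions has entropy ln 2, so with unit weights the sketch returns -ln 2.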
Lagrange Penalty
MPO Compute Weights and Temperature Loss
rlax.mpo_compute_weights_and_temperature_loss(sample_q_values, temperature_constraint, projection_operator, sample_axis=0)
Computes the weights and temperature loss for MPO.
The E-Step computes a nonparametric sample-based approximation of the current policy by reweighting the state-action value function.
Here, we compute this nonparametric policy and optimize the temperature parameter used in the reweighting.
- Parameters
sample_q_values (Array) – An array of shape E* + a sample axis inserted at sample_axis containing the q function values evaluated on the sampled actions.
temperature_constraint (LagrangePenalty) – Lagrange constraint for the E-step temperature optimization.
projection_operator (Callable[[Numeric], Numeric]) – Function to project temperature into the positive range.
sample_axis (int) – Axis in sample_q_values containing sampled actions.
- Return type
Tuple[Array, Array, Scalar]
- Returns
The temperature loss, normalized weights and number of action samples per state.
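The reweighting step can be illustrated for a single state with a flat list of sampled Q-values. The sketch below is plain Python and does not call RLax. The softmax form of the weights and the dual form of the temperature loss follow the standard MPO formulation and are assumptions here, as are the hypothetical names `temperature` and `epsilon`, which stand in for the contents of the Lagrange constraint.

```python
import math


def mpo_weights_and_temperature_loss_sketch(sample_q_values, temperature, epsilon):
    """Hypothetical single-state sketch of the MPO E-step quantities."""
    scaled = [q / temperature for q in sample_q_values]
    m = max(scaled)  # Subtract the max for numerical stability.
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    # Nonparametric policy: softmax of Q / temperature over sampled actions.
    weights = [e / z for e in exps]
    # log of the mean of exp(Q / temperature) over the samples.
    log_mean_exp = m + math.log(z / len(scaled))
    # Assumed dual objective for the temperature.
    temperature_loss = temperature * (epsilon + log_mean_exp)
    return temperature_loss, weights
```

A lower temperature concentrates the weights on the highest-valued samples, while a higher temperature moves them towards uniform.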
MPO Loss
rlax.mpo_loss(sample_log_probs, sample_q_values, temperature_constraint, kl_constraints, projection_operator=functools.partial(<CompiledFunction of <function clip>>, a_min=1e-10), policy_loss_weight=1.0, temperature_loss_weight=1.0, kl_loss_weight=1.0, alpha_loss_weight=1.0, sample_axis=0, use_stop_gradient=True)
Implements the MPO loss with a KL bound.
This loss implements the MPO algorithm for policies with a bound for the KL between the current and target policy.
Note: This is a per-example loss which works on any shape inputs as long as they are consistent. We denote this shape E* for ease of reference. Args sample_log_probs and sample_q_values are shape E* + an extra sample axis that contains the sampled actions’ log probs and q values respectively. For example, if sample_axis = 0, the shapes expected will be [S, E*]. Or if E* = [T, B] and sample_axis = 1, the shapes expected will be [T, S, B].
- Parameters
sample_log_probs (Array) – An array of shape E* + a sample axis inserted at sample_axis containing the log probabilities of the sampled actions under the current policy.
sample_q_values (Array) – An array of shape E* + a sample axis inserted at sample_axis containing the q function values evaluated on the sampled actions.
temperature_constraint (LagrangePenalty) – Lagrange constraint for the E-step temperature optimization.
kl_constraints (Sequence[Tuple[Array, LagrangePenalty]]) – KL and variables for applying Lagrangian penalties to bound them in the M-step, KLs are [E*, A?]. Here A is the action dimension in the case of per-dimension KL constraints.
projection_operator (Callable[[Numeric], Numeric]) – Function to project dual variables (temperature and kl constraint alphas) into the positive range.
policy_loss_weight (float) – Weight for the policy loss.
temperature_loss_weight (float) – Weight for the temperature loss.
kl_loss_weight (float) – Weight for the KL loss.
alpha_loss_weight (float) – Weight for the alpha loss.
sample_axis (int) – Axis in sample_log_probs and sample_q_values that contains the sampled actions’ log probs and q values respectively. For example, if sample_axis = 0, the shapes expected will be [S, E*]. Or if E* = [T, B] and sample_axis = 1, the shapes expected will be [T, S, B].
use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient.
- Return type
Tuple[Array, MpoOutputs]
- Returns
Per example loss with shape E*, and additional data including the components of this loss and the normalized weights in the MpoOutputs.
Policy Gradient Loss
rlax.policy_gradient_loss(logits_t, a_t, adv_t, w_t, use_stop_gradient=True)
Calculates the policy gradient loss.
See “Simple Gradient-Following Algorithms for Connectionist RL” by Williams. (http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)
- Parameters
logits_t (Array) – a sequence of unnormalized action preferences.
a_t (Array) – a sequence of actions sampled from the preferences logits_t.
adv_t (Array) – the observed or estimated advantages from executing actions a_t.
w_t (Array) – a per timestep weighting for the loss.
use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to advantages.
- Return type
Array
- Returns
Loss whose gradient corresponds to a policy gradient update.
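A surrogate loss of this kind is commonly written as the negative, advantage-weighted log-probability of the actions taken. The plain-Python sketch below evaluates that quantity; `policy_gradient_loss_sketch` is a hypothetical helper that does not call RLax, the averaging over timesteps is an assumption, and no gradients or stop-gradients are modelled.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def policy_gradient_loss_sketch(logits_t, a_t, adv_t, w_t):
    """Hypothetical sketch: mean of -w * advantage * log pi(a | s)."""
    total = 0.0
    for logits, a, adv, w in zip(logits_t, a_t, adv_t, w_t):
        log_prob = log_softmax(logits)[a]
        total += -w * adv * log_prob
    return total / len(a_t)
```

Differentiating this value with respect to the logits, with the advantages held fixed, gives the familiar score-function gradient estimate.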
QPG Loss
rlax.qpg_loss(logits_t, q_t, use_stop_gradient=True)
Computes the QPG (Q-based Policy Gradient) loss.
See “Actor-Critic Policy Optimization in Partially Observable Multiagent Environments” by Srinivasan, Lanctot (https://arxiv.org/abs/1810.09026).
- Parameters
logits_t (Array) – a sequence of unnormalized action preferences.
q_t (Array) – the observed or estimated action value from executing actions a_t at time t.
use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to advantages.
- Return type
Array
- Returns
QPG Loss.
RM Loss
rlax.rm_loss(logits_t, q_t, use_stop_gradient=True)
Computes the RMPG (Regret Matching Policy Gradient) loss.
The gradient of this loss adapts the Regret Matching rule by weighting the standard PG update with thresholded regret.
See “Actor-Critic Policy Optimization in Partially Observable Multiagent Environments” by Srinivasan, Lanctot (https://arxiv.org/abs/1810.09026).
- Parameters
logits_t (Array) – a sequence of unnormalized action preferences.
q_t (Array) – the observed or estimated action value from executing actions a_t at time t.
use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to advantages.
- Return type
Array
- Returns
RM Loss.
RPG Loss
rlax.rpg_loss(logits_t, q_t, use_stop_gradient=True)
Computes the RPG (Regret Policy Gradient) loss.
The gradient of this loss adapts the Regret Matching rule by weighting the standard PG update with regret.
See “Actor-Critic Policy Optimization in Partially Observable Multiagent Environments” by Srinivasan, Lanctot (https://arxiv.org/abs/1810.09026).
- Parameters
logits_t (Array) – a sequence of unnormalized action preferences.
q_t (Array) – the observed or estimated action value from executing actions a_t at time t.
use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient to advantages.
- Return type
Array
- Returns
RPG Loss.
VMPO Compute Weights and Temperature Loss
rlax.vmpo_compute_weights_and_temperature_loss(advantages, restarting_weights, importance_weights, temperature_constraint, projection_operator, top_k_fraction, axis_name=None, use_stop_gradient=True)
Computes the weights and temperature loss for V-MPO.
- Parameters
advantages (Array) – Advantages for the E-step. Shape E*.
restarting_weights (Array) – Restarting weights, 0 means that this step is the start of a new episode and we ignore losses at this step because the agent cannot influence these. Shape E*.
importance_weights (Array) – Optional importance weights. Shape E*
temperature_constraint (LagrangePenalty) – Lagrange constraint for the E-step temperature optimization.
projection_operator (Callable[[Numeric], Numeric]) – Function to project dual variables (temperature and kl constraint alphas) into the positive range.
top_k_fraction (float) – Fraction of samples to use in the E-step.
axis_name (Optional[str]) – Optional axis name for pmap or ‘vmap’. If None, computations are performed locally on each device.
use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient.
- Return type
Tuple[Scalar, Array, Scalar]
- Returns
The temperature loss, normalized weights and number of samples used.
VMPO Loss
rlax.vmpo_loss(sample_log_probs, advantages, temperature_constraint, kl_constraints, projection_operator=functools.partial(<CompiledFunction of <function clip>>, a_min=1e-10), restarting_weights=None, importance_weights=None, top_k_fraction=0.5, policy_loss_weight=1.0, temperature_loss_weight=1.0, kl_loss_weight=1.0, alpha_loss_weight=1.0, axis_name=None, use_stop_gradient=True)
Calculates the V-MPO policy improvement loss.
Note: This is a per-example loss which works on any shape inputs as long as they are consistent. We denote the shape of the examples E* for ease of reference.
- Parameters
sample_log_probs (Array) – Log probabilities of actions for each example. Shape E*.
advantages (Array) – Advantages for the E-step. Shape E*.
temperature_constraint (LagrangePenalty) – Lagrange constraint for the E-step temperature optimization.
kl_constraints (Sequence[Tuple[Array, LagrangePenalty]]) – KL and variables for applying Lagrangian penalties to bound them in the M-step, KLs are E* or [E*, A]. Here A is the action dimension in the case of per-dimension KL constraints.
projection_operator (Callable[[Numeric], Numeric]) – Function to project dual variables (temperature and kl constraint alphas) into the positive range.
restarting_weights (Optional[Array]) – Optional restarting weights, shape E*, 0 means that this step is the start of a new episode and we ignore losses at this step because the agent cannot influence these.
importance_weights (Optional[Array]) – Optional importance weights, shape E*.
top_k_fraction (float) – Fraction of samples to use in the E-step.
policy_loss_weight (float) – Weight for the policy loss.
temperature_loss_weight (float) – Weight for the temperature loss.
kl_loss_weight (float) – Weight for the KL loss.
alpha_loss_weight (float) – Weight for the alpha loss.
axis_name (Optional[str]) – Optional axis name for pmap. If None, computations are performed locally on each device.
use_stop_gradient (bool) – bool indicating whether or not to apply stop gradient.
- Return type
Tuple[Array, MpoOutputs]
- Returns
Per example loss with same shape E* as array inputs, and additional data including the components of this loss and the normalized weights in the MpoOutputs.
Exploration
add_dirichlet_noise – Returns discrete actions with noise drawn from a Dirichlet distribution.
add_gaussian_noise – Returns continuous action with noise drawn from a Gaussian distribution.
add_ornstein_uhlenbeck_noise – Returns continuous action with noise from Ornstein-Uhlenbeck process.
episodic_memory_intrinsic_rewards – Compute intrinsic rewards for exploration via episodic memory.
Add Dirichlet Noise
rlax.add_dirichlet_noise(key, prior, dirichlet_alpha, dirichlet_fraction)
Returns discrete actions with noise drawn from a Dirichlet distribution.
See “Mastering the Game of Go without Human Knowledge” by Silver et al. 2017 (https://discovery.ucl.ac.uk/id/eprint/10045895/1/agz_unformatted_nature.pdf), “A General Reinforcement Learning Algorithm that Masters Chess, Shogi and Go Through Self-Play” by Silver et al. 2018 (http://airesearch.com/wp-content/uploads/2016/01/deepmind-mastering-go.pdf), and “Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model” by Schrittwieser et al., 2019 (https://arxiv.org/abs/1911.08265).
The AlphaZero family of algorithms adds noise sampled from a symmetric Dirichlet distribution to the prior policy generated by MCTS. Because the agent then samples from this new, noisy prior over actions, this encourages better exploration of the root node’s children.
Specifically, this computes:
noise ~ Dirichlet(alpha)
noisy_prior = (1 - fraction) * prior + fraction * noise
Note that alpha is a single float to draw from a symmetric Dirichlet.
For reference values, AlphaZero uses 0.3, 0.15, 0.03 for Chess, Shogi, and Go respectively, and MuZero uses 0.25 for Atari.
- Parameters
key (Array) – a key from jax.random.
prior (Array) – 2-dim continuous prior policy vector of shape [B, N], for B batch size and N num_actions.
dirichlet_alpha (float) – concentration parameter to parametrize Dirichlet distribution.
dirichlet_fraction (float) – float from 0 to 1 interpolating between using only the prior policy or just the noise.
- Return type
Array
- Returns
noisy prior, of the same shape as the input prior.
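The mixing rule above can be tried out with the Python standard library alone. The sketch below handles a single, unbatched prior vector and does not call RLax or jax.random; `add_dirichlet_noise_sketch` is a hypothetical helper, and sampling the symmetric Dirichlet by normalizing independent Gamma draws is a standard construction chosen here for illustration.

```python
import random


def add_dirichlet_noise_sketch(rng, prior, dirichlet_alpha, dirichlet_fraction):
    """Hypothetical sketch: mix a prior with symmetric Dirichlet noise."""
    # A symmetric Dirichlet sample is a normalized vector of Gamma(alpha, 1) draws.
    gammas = [rng.gammavariate(dirichlet_alpha, 1.0) for _ in prior]
    total = sum(gammas)
    noise = [g / total for g in gammas]
    # Interpolate between the prior and the noise.
    return [
        (1.0 - dirichlet_fraction) * p + dirichlet_fraction * n
        for p, n in zip(prior, noise)
    ]


rng = random.Random(0)
print(add_dirichlet_noise_sketch(rng, [0.5, 0.25, 0.25], 0.3, 0.25))
```

Because both the prior and the noise sum to one, the noisy prior is again a valid probability vector.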
Add Gaussian Noise
rlax.add_gaussian_noise(key, action, stddev)
Returns continuous action with noise drawn from a Gaussian distribution.
- Parameters
key (Array) – a key from jax.random.
action (Array) – continuous action scalar or vector.
stddev (float) – standard deviation of noise distribution.
- Return type
Array
- Returns
noisy action, of the same shape as input action.
Add Ornstein Uhlenbeck Noise
rlax.add_ornstein_uhlenbeck_noise(key, action, noise_tm1, damping, stddev)
Returns continuous action with noise from Ornstein-Uhlenbeck process.
See “On the theory of Brownian Motion” by Uhlenbeck and Ornstein. (https://journals.aps.org/pr/abstract/10.1103/PhysRev.36.823).
- Parameters
key (Array) – a key from jax.random.
action (Array) – continuous action scalar or vector.
noise_tm1 (Array) – noise sampled from OU process in previous timestep.
damping (float) – parameter for controlling autocorrelation of OU process.
stddev (float) – standard deviation of noise distribution.
- Return type
Array
- Returns
noisy action, of the same shape as input action.
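One common discretization of an Ornstein-Uhlenbeck process decays the previous noise by (1 - damping) and adds fresh Gaussian noise. The plain-Python sketch below uses that update rule as an assumption; it does not call RLax, and `add_ou_noise_sketch` is a hypothetical helper that also returns the new noise so the caller can carry it to the next step.

```python
import random


def add_ou_noise_sketch(rng, action, noise_tm1, damping, stddev):
    """Hypothetical sketch of temporally correlated exploration noise."""
    # Assumed update: decay the previous noise, then add fresh Gaussian noise.
    noise_t = [
        (1.0 - damping) * n + stddev * rng.gauss(0.0, 1.0) for n in noise_tm1
    ]
    noisy_action = [a + n for a, n in zip(action, noise_t)]
    return noisy_action, noise_t
```

With damping = 1 the previous noise is discarded and the result matches independent Gaussian noise; smaller damping values make consecutive perturbations more strongly correlated.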
Episodic Memory Intrinsic Rewards
rlax.episodic_memory_intrinsic_rewards(embeddings, num_neighbors, reward_scale, intrinsic_reward_state=None, constant=0.001, epsilon=0.0001, cluster_distance=0.008, max_similarity=8.0, max_memory_size=30000)
Compute intrinsic rewards for exploration via episodic memory.
This method is adapted from the intrinsic reward computation used in “Never Give Up: Learning Directed Exploration Strategies” by Puigdomènech Badia et al., (2020) (https://arxiv.org/abs/2003.13350) and “Agent57: Outperforming the Atari Human Benchmark” by Puigdomènech Badia et al., (2020) (https://arxiv.org/abs/2002.06038).
From an embedding, we compute the intra-episode intrinsic reward with respect to a pre-existing set of embeddings.
NOTE: For this function to be jittable, static_argnums=[1,] must be passed, as the internal jax.lax.top_k(neg_distances, num_neighbors) computation in knn_query cannot be jitted with a dynamic num_neighbors that is passed as an argument.
- Parameters
embeddings (Array) – Array, shaped [M, D] for number of new state embeddings M and feature dim D.
num_neighbors (int) – int for K neighbors used in kNN query
reward_scale (float) – The β term used in the Agent57 paper to scale the reward.
intrinsic_reward_state (Optional[IntrinsicRewardState]) – An IntrinsicRewardState namedtuple, containing:
- memory: Array; an array of previous memories within an episode padded with zeros up to max_memory_size.
- next_memory_index: The index in the static memory array to add next embeddings at. Is updated in a ring buffer fashion.
- distance_sum: Scalar, Optional; running sum of total negative squared distances computed by consecutive kNN queries used to compute mean distance.
- distance_count: Scalar, Optional; running count of total negative squared distances computed by consecutive kNN queries used to compute mean distance.
NOTE: On (only) the first call to episodic_memory_intrinsic_rewards, the intrinsic_reward_state is optional; if None is given, an IntrinsicRewardState will be initialized with default parameters. Specifically, the memory will be initialized to an array of jnp.inf of shape [max_memory_size x feature dim D], and default values of 0 will be provided for next_memory_index, distance_sum, and distance_count.
constant (float) – float; small constant used for numerical stability used during normalizing distances.
epsilon (float) – float; small constant used for numerical stability when computing kernel output.
cluster_distance (float) – float; the ξ term used in the Agent57 paper to bound the distance rate used in the kernel computation.
max_similarity (float) – float; max limit of similarity; used to zero rewards when similarity between memories is too high to be considered ‘useful’ for an agent.
max_memory_size (int) – int; the maximum number of memories to store. Note that performance will be marginally faster if max_memory_size is an exact multiple of M (the number of embeddings to add to memory per call to episodic_memory_intrinsic_rewards).
- Returns
reward: Array, shaped [M, 1]; intrinsic reward for each embedding computed by using similarity measure to memories.
intrinsic_reward_state: An IntrinsicRewardState namedtuple, containing:
- memory: Array; an array of previous memories within an episode padded with zeros up to max_memory_size.
- next_memory_index: The index in the static memory array to add next embeddings at. Is updated in a ring buffer fashion.
- distance_sum: Scalar, Optional; running sum of total negative squared distances computed by consecutive kNN queries used to compute mean distance.
- distance_count: Scalar, Optional; running count of total negative squared distances computed by consecutive kNN queries used to compute mean distance.
KNN Query
Utilities
Helper for summing over elements in an array and over devices.
Index into the last dimension of a tensor, preserving all other dims.
Ensures that source is compatible with target for broadcasting.
Returns a one-hot version of indices.
Embed each of the (observation, action, reward) inputs & concatenate.
Map a function over a list of identical nested structures.
Select either one of two identical nested structs based on condition.
Generate random keys for each leaf in a tree.
Splits a tree of arrays into an array of trees avoiding data copying.
Checks whether to update the params and returns the correct params.
Incrementally update all elements from a nested struct.
Periodically switch all elements from a nested struct with new elements.
All Sum
Batched Index
rlax.batched_index(values, indices, keepdims=False)
Index into the last dimension of a tensor, preserving all other dims.
- Parameters
values (Array) – a tensor of shape […, D].
indices (Array) – indices of shape […].
keepdims (bool) – whether to keep the final dimension.
- Return type
Array
- Returns
a tensor of shape […] or […, 1].
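For the two-dimensional case, the indexing behaviour can be sketched with nested Python lists. `batched_index_sketch` is a hypothetical helper that does not call RLax; it assumes values has shape [B, D] and indices has shape [B].

```python
def batched_index_sketch(values, indices, keepdims=False):
    """Hypothetical sketch: pick one entry from the last axis of each row."""
    picked = [row[i] for row, i in zip(values, indices)]
    # Optionally keep a trailing axis of size one.
    return [[p] for p in picked] if keepdims else picked


q_values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
actions = [2, 0]
print(batched_index_sketch(q_values, actions))  # One Q-value per row.
```

This pattern is typically used to select the value of the action taken from a batch of per-action values.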
LHS Broadcast
One Hot
rlax.one_hot(indices, num_classes, dtype=<class 'jax._src.numpy.lax_numpy.float32'>)
Returns a one-hot version of indices.
- Parameters
indices – A tensor of indices.
num_classes – Number of classes in the one-hot dimension.
dtype – The dtype.
- Returns
The one-hot tensor. If indices’ shape is [A, B, …], shape is [A, B, …, num_classes].
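For a flat list of indices, the resulting shape [A, num_classes] can be illustrated with plain Python. `one_hot_sketch` is a hypothetical helper that does not call RLax and ignores the dtype argument.

```python
def one_hot_sketch(indices, num_classes):
    """Hypothetical sketch: one row of length num_classes per index."""
    return [
        [1.0 if c == i else 0.0 for c in range(num_classes)] for i in indices
    ]


print(one_hot_sketch([0, 2], 3))
```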
Embed OAR
Tree Map Zipped
rlax.tree_map_zipped(fn, nests)
Map a function over a list of identical nested structures.
- Parameters
fn (Callable[..., Any]) – the function to map; must have arity equal to len(nests).
nests (Sequence[Any]) – a list of identical nested structures.
- Returns
a nested structure whose leaves are outputs of applying fn.
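The idea of mapping over several identically structured nests can be sketched for nests built from dicts, lists and tuples. `tree_map_zipped_sketch` is a hypothetical pure-Python helper; it does not call RLax and does not use JAX's pytree machinery.

```python
def tree_map_zipped_sketch(fn, nests):
    """Hypothetical sketch: apply fn to corresponding leaves of each nest."""
    first = nests[0]
    if isinstance(first, dict):
        # Recurse into each key, gathering the matching child from every nest.
        return {
            key: tree_map_zipped_sketch(fn, [nest[key] for nest in nests])
            for key in first
        }
    if isinstance(first, (list, tuple)):
        # Recurse position by position, preserving the container type.
        return type(first)(
            tree_map_zipped_sketch(fn, list(group)) for group in zip(*nests)
        )
    # Leaves: call fn with one argument per nest.
    return fn(*nests)


params_a = {"w": 1.0, "b": [2.0, 3.0]}
params_b = {"w": 10.0, "b": [20.0, 30.0]}
print(tree_map_zipped_sketch(lambda x, y: x + y, [params_a, params_b]))
```

The arity of fn matches the number of nests, so two nests require a two-argument function.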
Tree Select
rlax.tree_select(pred, on_true, on_false)
Select either one of two identical nested structs based on condition.
- Parameters
pred (Array) – a boolean condition.
on_true (Any) – an arbitrary nested structure.
on_false (Any) – a nested structure identical to on_true.
- Returns
the selected nested structure.
Tree Split Key
Tree Split Leaves
rlax.tree_split_leaves(tree_like, axis=0, keepdim=False)
Splits a tree of arrays into an array of trees avoiding data copying.
Note: jax.numpy.DeviceArray’s data gets copied.
- Parameters
tree_like (Any) – a nested object with leaves to split.
axis (int) – an axis for splitting.
keepdim (bool) – a bool indicating whether to keep axis dimension.
- Returns
A tuple of size(axis) trees containing results of splitting.
Conditional Update
Incremental Update
General Value Functions
pixel_control_rewards – Calculates cumulants for pixel control tasks from an observation sequence.
feature_control_rewards – Calculates cumulants for feature control tasks from a sequence of features.
Pixel Control Rewards
rlax.pixel_control_rewards(observations, cell_size)
Calculates cumulants for pixel control tasks from an observation sequence.
The observations are first split in a grid of KxK cells. For each cell a distinct pseudo reward is computed as the average absolute change in pixel intensity across all pixels in the cell. The change in intensity is averaged across both pixels and channels (e.g. RGB).
The observations provided to this function should be cropped suitably, to ensure that the observations’ height and width are a multiple of cell_size. The values of the observations tensor should be rescaled to [0, 1].
See “Reinforcement Learning with Unsupervised Auxiliary Tasks” by Jaderberg, Mnih, Czarnecki et al. (https://arxiv.org/abs/1611.05397).
- Parameters
observations (Array) – A tensor of shape [T+1,H,W,C], where T is the sequence length, H is the height, W is the width, and C is the channel dimension.
cell_size (int) – The size of each cell.
- Return type
Array
- Returns
A tensor of pixel control rewards calculated from the observation. The shape is [T,H’,W’], where H’=H/cell_size and W’=W/cell_size.
Feature Control Rewards
rlax.feature_control_rewards(features, cumulant_type='absolute_change', discount=None)
Calculates cumulants for feature control tasks from a sequence of features.
For each feature dimension, a distinct pseudo reward is computed based on the change in the feature value between consecutive timesteps. Depending on cumulant_type, cumulants may be equal to the features themselves, the absolute difference between their values in consecutive steps, their increase/decrease, or may take the form of a potential-based reward discounted by discount.
See “Reinforcement Learning with Unsupervised Auxiliary Tasks” by Jaderberg, Mnih, Czarnecki et al. (https://arxiv.org/abs/1611.05397).
- Parameters
features (Array) – A tensor of shape [T+1,D] of features.
cumulant_type – either ‘feature’ (feature is the reward), ‘absolute_change’ (the reward equals the absolute difference between consecutive timesteps), ‘increase’ (the reward equals the increase in the value of the feature), ‘decrease’ (the reward equals the decrease in the value of the feature), or ‘potential’ (r=gamma*phi_{t+1} - phi_t).
discount – (optional) discount for potential based rewards.
- Return type
Array
- Returns
A tensor of cumulants calculated from the features. The shape is [T,D].
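The cumulant types listed above can be sketched for a [T+1, D] list of feature rows. `feature_control_rewards_sketch` is a hypothetical pure-Python helper that does not call RLax; for the 'feature' type it assumes the later timestep's features are used as the reward.

```python
def feature_control_rewards_sketch(
    features, cumulant_type="absolute_change", discount=None
):
    """Hypothetical sketch: one row of D cumulants per consecutive pair."""
    rewards = []
    for prev, curr in zip(features[:-1], features[1:]):
        if cumulant_type == "feature":
            row = list(curr)  # Assumption: use the later timestep's features.
        elif cumulant_type == "absolute_change":
            row = [abs(c - p) for p, c in zip(prev, curr)]
        elif cumulant_type == "increase":
            row = [c - p for p, c in zip(prev, curr)]
        elif cumulant_type == "decrease":
            row = [p - c for p, c in zip(prev, curr)]
        elif cumulant_type == "potential":
            # r = gamma * phi_{t+1} - phi_t
            row = [discount * c - p for p, c in zip(prev, curr)]
        else:
            raise ValueError(f"Unknown cumulant_type: {cumulant_type}")
        rewards.append(row)
    return rewards
```

A sequence of T+1 feature rows yields T rows of cumulants, matching the [T,D] output shape.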
Pop Art
art – Adaptively rescale targets.
normalize – Returns normalized values.
pop – Preserves outputs precisely.
popart – Generates functions giving initial PopArt state and update rule.
unnormalize – Returns unnormalized values.
unnormalize_linear – Selects and unnormalizes output of a Linear.
Art
rlax.art(state, targets, indices, step_size, scale_lb, scale_ub, axis_name=None)
Adaptively rescale targets.
- Parameters
state (PopArtState) – The PopArt summary stats.
targets (Array) – targets which are rescaled.
indices (Array) – Which indices of the state to use.
step_size (float) – The step size for learning the scale & shift parameters.
scale_lb (float) – Lower bound for the scale.
scale_ub (float) – Upper bound for the scale.
axis_name – What axis to aggregate over, if str. If passed an iterable, aggregates over multiple axes. Defaults to no aggregation, i.e. None.
- Return type
PopArtState
- Returns
New popart state which can be used to rescale targets.
Normalize¶
rlax.normalize(state, unnormalized, indices)
Returns normalized values.
- Parameters
state (PopArtState) – The PopArt summary stats.
unnormalized (Array) – unnormalized values that we applied PopArt to.
indices (Array) – Which scales and shifts to use.
- Return type
Array
- Returns
Normalized PopArt values.
Pop¶
rlax.pop(params, old, new)
Preserves outputs precisely.
- Parameters
params (LinearParams) – The parameters of the linear layer to preserve.
old (PopArtState) – The old PopArt state.
new (PopArtState) – The new PopArt state.
- Returns
The new parameters.
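To see why outputs can be preserved precisely when the statistics change, consider a linear layer whose normalized output y = x @ w + b is unnormalized as scale * y + shift. The NumPy sketch below rescales the weights and bias so that the unnormalized output is unchanged under the new statistics. The helper pop_params and the plain (w, b) parameter layout are assumptions made for illustration and do not reflect the LinearParams structure.

```python
import numpy as np

def pop_params(w, b, shift_old, scale_old, shift_new, scale_new):
    """Rescale a linear layer so its unnormalized outputs are unchanged."""
    w_new = w * (scale_old / scale_new)  # broadcasts over the output axis
    b_new = (scale_old * b + shift_old - shift_new) / scale_new
    return w_new, b_new

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))                          # 4 inputs, 3 features
w, b = rng.normal(size=(3, 2)), rng.normal(size=2)   # 2 outputs
shift_old, scale_old = np.array([0.0, 1.0]), np.array([1.0, 2.0])
shift_new, scale_new = np.array([0.5, -1.0]), np.array([3.0, 0.5])

before = scale_old * (x @ w + b) + shift_old
w_new, b_new = pop_params(w, b, shift_old, scale_old, shift_new, scale_new)
after = scale_new * (x @ w_new + b_new) + shift_new
```

Substituting w_new and b_new into the second expression cancels scale_new and shift_new, so before and after agree up to floating-point rounding.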
PopArt¶
rlax.popart(num_outputs, step_size, scale_lb, scale_ub, axis_name=None)
Generates functions giving initial PopArt state and update rule.
- Parameters
num_outputs (int) – The number of outputs generated by the linear we’re preserving.
step_size (float) – The step size for learning the scale & shift parameters.
scale_lb (float) – Lower bound for the scale.
scale_ub (float) – Upper bound for the scale.
axis_name – What axis to aggregate over, if str. If passed an iterable, aggregates over multiple axes. Defaults to no aggregation, i.e. None.
- Returns
A tuple of two functions: initial_state, which returns the initial PopArt state, and popart_update, which updates the PopArt state and the parameters of the preceding linear layer.
PopArtState¶
Unnormalize¶
rlax.unnormalize(state, normalized, indices)
Returns unnormalized values.
- Parameters
state (PopArtState) – The PopArt summary stats.
normalized (Array) – normalized values that we apply PopArt to.
indices (Array) – Which scales and shifts to use.
- Return type
Array
- Returns
Unnormalized PopArt values.
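Normalizing and unnormalizing are inverse affine maps. The sketch below shows the arithmetic under the assumption that the state holds one shift and one scale per output and that indices selects which pair applies to each value; the helper names and the bare shift/scale arrays are illustrative stand-ins for PopArtState.

```python
import numpy as np

def normalize_sketch(shift, scale, unnormalized, indices):
    return (unnormalized - shift[indices]) / scale[indices]

def unnormalize_sketch(shift, scale, normalized, indices):
    return scale[indices] * normalized + shift[indices]

shift = np.array([10.0, -2.0])   # one entry per output
scale = np.array([5.0, 0.5])
indices = np.array([0, 1, 1])    # which output each value belongs to
values = np.array([20.0, -1.0, -3.0])

normalized = normalize_sketch(shift, scale, values, indices)
recovered = unnormalize_sketch(shift, scale, normalized, indices)
```

Here normalized evaluates to [2, 2, -2], and unnormalizing it recovers the original values.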
Unnormalize Linear¶
rlax.unnormalize_linear(state, inputs, indices)
Selects and unnormalizes output of a Linear.
- Parameters
state (PopArtState) – The PopArt summary stats.
inputs (Array) – The (normalized) output of the Linear that we apply PopArt to.
indices (Array) – Which indices of inputs to use.
- Return type
PopArtOutput
- Returns
PopArtOutput, a tuple of the normalized and unnormalized PopArt values.
Transforms¶
|
TxPair(apply, apply_inv) |
|
Identity transform. |
|
TxPair(apply, apply_inv) |
|
Logit transform, inverse of sigmoid. |
|
Power transform; power_tx(_, 1/p) is the inverse of power_tx(_, p). |
|
Sigmoid transform. |
|
Signed exponential of x - 1, inverse of signed_logp1. |
|
Signed hyperbolic transform, inverse of signed_parabolic. |
|
TxPair(apply, apply_inv) |
|
Signed logarithm of x + 1. |
|
TxPair(apply, apply_inv) |
|
Signed parabolic transform, inverse of signed_hyperbolic. |
|
Transforms from a categorical distribution to a scalar. |
|
Transforms a scalar tensor to a 2 hot representation. |
|
Power¶
Signed Exponential¶
Signed Hyperbolic¶
Signed Parabolic¶
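The signed hyperbolic and signed parabolic transforms are described above as inverses of each other. A minimal NumPy sketch of one common formulation follows. The eps parameter and its default of 1e-3 are assumptions of this sketch, and the function names simply reuse the names from the summary table without claiming to match the library's signatures.

```python
import numpy as np

def signed_hyperbolic(x, eps=1e-3):
    """Compresses large magnitudes while keeping the sign."""
    return np.sign(x) * (np.sqrt(np.abs(x) + 1.0) - 1.0) + eps * x

def signed_parabolic(x, eps=1e-3):
    """Inverse of signed_hyperbolic for the same eps."""
    z = (np.sqrt(1.0 + 4.0 * eps * (eps + 1.0 + np.abs(x))) - 1.0) / (2.0 * eps)
    return np.sign(x) * (np.square(z) - 1.0)

x = np.array([-100.0, -1.0, 0.0, 1.5, 1000.0])
round_trip = signed_parabolic(signed_hyperbolic(x))
```

The inverse comes from solving the quadratic eps * s**2 + s - (1 + eps + |y|) = 0 for s = sqrt(|x| + 1), so round_trip matches x up to floating-point error.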
Transform from 2 Hot¶
Losses¶
l2_loss | Calculates the L2 loss of predictions wrt targets. |
likelihood | Calculates the likelihood of predictions wrt targets. |
log_loss | Calculates the log loss of predictions wrt targets. |
huber_loss | Huber loss, similar to L2 loss close to zero, L1 loss away from zero. |
pixel_control_loss | Calculate n-step Q-learning loss for pixel control auxiliary task. |
L2 Loss¶
rlax.l2_loss(predictions, targets=None)
Calculates the L2 loss of predictions wrt targets.
If targets are not provided, this function acts as an L2-regularizer for predictions.
Note: the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not in “The Elements of Statistical Learning” by Hastie, Tibshirani, and Friedman.
- Parameters
predictions (Array) – a vector of arbitrary shape.
targets (Optional[Array]) – a vector of shape compatible with predictions.
- Return type
Array
- Returns
a vector of the same shape as predictions.
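The behaviour described above, including the 0.5 factor and the regularizer case, can be written as a short NumPy sketch. The helper l2_loss_sketch is an illustrative stand-in rather than the library function.

```python
import numpy as np

def l2_loss_sketch(predictions, targets=None):
    """Elementwise 0.5 * (predictions - targets)**2; targets default to zero."""
    if targets is None:
        targets = np.zeros_like(predictions)
    return 0.5 * np.square(predictions - targets)

predictions = np.array([1.0, 3.0])
loss = l2_loss_sketch(predictions, np.array([0.0, 1.0]))
penalty = l2_loss_sketch(predictions)  # acts as an L2 regularizer
```

With the 0.5 factor, the gradient with respect to predictions is simply predictions - targets.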
Likelihood¶
Log Loss¶
rlax.log_loss(predictions, targets)
Calculates the log loss of predictions wrt targets.
- Parameters
predictions (Array) – a vector of probabilities of arbitrary shape.
targets (Array) – a vector of probabilities of shape compatible with predictions.
- Return type
Array
- Returns
a vector of the same shape as predictions.
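One way to read the log loss is as the negative log of a Bernoulli likelihood, where predictions are probabilities and targets are probabilities or 0/1 labels. The sketch below assumes that form; likelihood_sketch and log_loss_sketch are illustrative names, not the library functions.

```python
import numpy as np

def likelihood_sketch(predictions, targets):
    # Assumed Bernoulli form: p**t * (1 - p)**(1 - t).
    return predictions**targets * (1.0 - predictions) ** (1.0 - targets)

def log_loss_sketch(predictions, targets):
    return -np.log(likelihood_sketch(predictions, targets))

predictions = np.array([0.5, 0.9])
targets = np.array([1.0, 0.0])
loss = log_loss_sketch(predictions, targets)
```

For hard 0/1 targets this reduces to the usual binary cross-entropy, -log(p) when the target is 1 and -log(1 - p) when it is 0.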
Huber Loss¶
rlax.huber_loss(x, delta=1.0)
Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
See “Robust Estimation of a Location Parameter” by Huber. (https://projecteuclid.org/download/pdf_1/euclid.aoms/1177703732).
- Parameters
x (Array) – a vector of arbitrary shape.
delta (float) – the bounds for the Huber loss transformation; defaults to 1.
Note: grad(huber_loss(x)) is equivalent to grad(0.5 * clip_gradient(x)**2).
- Return type
Array
- Returns
a vector of the same shape as x.
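The quadratic-then-linear shape of the Huber loss, and the note that its gradient behaves like a clipped error, can be checked numerically. The sketch below is a NumPy reference with the hypothetical name huber_loss_sketch; the finite-difference check stands in for automatic differentiation.

```python
import numpy as np

def huber_loss_sketch(x, delta=1.0):
    abs_x = np.abs(x)
    quadratic = np.minimum(abs_x, delta)  # part of |x| inside the bound
    linear = abs_x - quadratic            # part of |x| beyond the bound
    return 0.5 * quadratic**2 + delta * linear

loss = huber_loss_sketch(np.array([0.5, -3.0]))

# Central finite differences approximate the gradient away from the kinks.
x = np.array([-3.0, -0.5, 0.2, 4.0])
h = 1e-6
finite_diff_grad = (huber_loss_sketch(x + h) - huber_loss_sketch(x - h)) / (2 * h)
clipped = np.clip(x, -1.0, 1.0)
```

The finite-difference gradient agrees with x clipped to [-delta, delta], which is why the Huber loss is often described as L2 loss with gradient clipping.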
Pixel Control Loss¶
rlax.pixel_control_loss(observations, actions, action_values, discount_factor, cell_size)
Calculates the n-step Q-learning loss for the pixel control auxiliary task.
For each pixel-based pseudo reward signal, the corresponding action-value function is trained off-policy, using Q(lambda). A discount of 0.9 is commonly used for learning the value functions.
Note that, since pseudo rewards have a spatial structure, with neighbouring cells exhibiting strong correlations, it is convenient to predict the action values for all the cells through a deconvolutional head.
See “Reinforcement Learning with Unsupervised Auxiliary Tasks” by Jaderberg, Mnih, Czarnecki et al. (https://arxiv.org/abs/1611.05397).
- Parameters
observations (Array) – A tensor of shape [T+1, …]; … is the observation shape, T the sequence length.
actions (Array) – A tensor, shape [T,], of the actions across each sequence.
action_values (Array) – A tensor, shape [T+1, H, W, N] of pixel control action values, where H, W are the number of pixel control cells/tasks, and N is the number of actions.
discount_factor (Union[Array, Scalar]) – discount used for learning the value function associated with the pseudo rewards; must be a scalar or a Tensor of shape [T].
cell_size (int) – size of the cells used to derive the pixel based pseudo-rewards.
- Returns
a tensor containing the spatial loss, shape [T, H, W].
- Raises
ValueError – if the shape of action_values is not compatible with that of the pseudo-rewards derived from the observations.
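The role of cell_size is easiest to see in how pixel-based pseudo-rewards might be derived. The sketch below follows the general recipe from Jaderberg et al.: take the absolute change between consecutive observations and average it over non-overlapping cells and channels. It covers only the pseudo-reward step, not the Q-learning loss, and the helper name, the [T+1, H, W, C] layout, and the requirement that H and W divide evenly by cell_size are all assumptions of this sketch.

```python
import numpy as np

def pixel_control_rewards_sketch(observations, cell_size):
    """Observations [T+1, H, W, C] -> rewards [T, H // cell_size, W // cell_size]."""
    obs = np.asarray(observations, dtype=np.float64)
    abs_diff = np.abs(obs[1:] - obs[:-1])  # [T, H, W, C]
    t, h, w, c = abs_diff.shape
    if h % cell_size or w % cell_size:
        # Simplification of this sketch: require an exact tiling of the image.
        raise ValueError("H and W must be divisible by cell_size")
    cells = abs_diff.reshape(t, h // cell_size, cell_size, w // cell_size, cell_size, c)
    return cells.mean(axis=(2, 4, 5))  # average within each cell and over channels

observations = np.zeros((3, 4, 4, 1))
observations[1, 0:2, 0:2, 0] = 1.0  # the top-left cell lights up at the middle step
rewards = pixel_control_rewards_sketch(observations, cell_size=2)
```

The resulting grid has one pseudo-reward per cell and timestep, which is the spatial shape the action values must be compatible with.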
Distributions¶
|
Computes the softmax cross entropy between sets of logits and labels. |
|
Compute importance sampling ratios from logits. |
|
Compute the KL between two categorical distributions from their logits. |
|
Sample from a set of discrete probabilities. |
|
A softmax distribution with clipped entropy (1 is equivalent to not clipping). |
|
An epsilon-greedy distribution. |
|
An epsilon-softmax distribution. |
|
A gaussian distribution with diagonal covariance matrix. |
|
A greedy distribution. |
|
Compute the KL between two Gaussian distributions with diagonal covariance matrices. |
|
Tolerantly handles the temperature=0 case. |
|
A softmax distribution. |
|
A squashed gaussian distribution with diagonal covariance matrix. |
Categorical Cross Entropy¶
rlax.categorical_cross_entropy(labels, logits)
Computes the softmax cross entropy between sets of logits and labels.
See “Deep Learning” by Goodfellow et al. (http://www.deeplearningbook.org/contents/prob.html).
- Parameters
labels (Array) – a valid probability distribution (non-negative, sum to 1).
logits (Array) – unnormalized log probabilities.
- Return type
Array
- Returns
a scalar loss.
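As a reference for the quantity being computed, the softmax cross entropy is the negative sum of the labels times the log-softmax of the logits. The NumPy sketch below uses hypothetical helper names and reduces over the last axis.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

def categorical_cross_entropy_sketch(labels, logits):
    return -np.sum(labels * log_softmax(logits), axis=-1)

labels = np.array([0.0, 1.0, 0.0, 0.0])  # a valid distribution (one-hot here)
logits = np.zeros(4)                     # uniform prediction over 4 classes
loss = categorical_cross_entropy_sketch(labels, logits)
```

With uniform logits over four classes, the loss equals log(4) whichever class the one-hot label selects.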
Categorical Importance Sampling Ratios¶
rlax.categorical_importance_sampling_ratios(pi_logits_t, mu_logits_t, a_t)
Compute importance sampling ratios from logits.
- Parameters
pi_logits_t (Array) – unnormalized logits at time t for the target policy.
mu_logits_t (Array) – unnormalized logits at time t for the behavior policy.
a_t (Array) – actions at time t.
- Return type
Array
- Returns
importance sampling ratios.
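The importance sampling ratio for each step is the target-policy probability of the taken action divided by its behavior-policy probability. The sketch below assumes batched inputs of shape [T, num_actions] for the logits and [T] for integer actions; the helper names are illustrative.

```python
import numpy as np

def softmax(logits):
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

def importance_ratios_sketch(pi_logits_t, mu_logits_t, a_t):
    rows = np.arange(a_t.shape[0])
    # pi(a_t | s_t) / mu(a_t | s_t) for each timestep.
    return softmax(pi_logits_t)[rows, a_t] / softmax(mu_logits_t)[rows, a_t]

pi_logits_t = np.log(np.array([[0.8, 0.2], [0.5, 0.5]]))
mu_logits_t = np.log(np.array([[0.5, 0.5], [0.25, 0.75]]))
a_t = np.array([0, 1])
ratios = importance_ratios_sketch(pi_logits_t, mu_logits_t, a_t)
```

In this example the ratios are 0.8 / 0.5 and 0.5 / 0.75: the first action is more likely under the target policy, the second less likely.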
Categorical KL Divergence¶
rlax.categorical_kl_divergence(p_logits, q_logits, temperature=1.0)
Compute the KL between two categorical distributions from their logits.
- Parameters
p_logits (Array) – unnormalized logits for the first distribution.
q_logits (Array) – unnormalized logits for the second distribution.
temperature (float) – the temperature for the softmax distribution; defaults to 1.
- Return type
Array
- Returns
the KL divergence between the distributions.
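For reference, the KL divergence between two categorical distributions given by logits can be computed from their log-softmax values. In the sketch below, dividing both sets of logits by the temperature and clamping tiny negative results to zero are assumptions of this illustration, and the helper names are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

def categorical_kl_sketch(p_logits, q_logits, temperature=1.0):
    log_p = log_softmax(p_logits / temperature)
    log_q = log_softmax(q_logits / temperature)
    kl = np.sum(np.exp(log_p) * (log_p - log_q), axis=-1)
    # Clamp rounding noise so the result is never slightly negative.
    return np.maximum(kl, 0.0)

p_logits = np.log(np.array([0.5, 0.5]))
q_logits = np.log(np.array([0.9, 0.1]))
kl_pq = categorical_kl_sketch(p_logits, q_logits)
kl_pp = categorical_kl_sketch(p_logits, p_logits)
```

The divergence is zero when both arguments describe the same distribution and positive otherwise; it is not symmetric in p and q.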
Categorical Sample¶
Clipped Entropy Softmax¶
Epsilon Softmax¶
Gaussian Diagonal¶
Multivariate Normal KL Divergence¶
rlax.multivariate_normal_kl_divergence(mu_0, sigma_0, mu_1, sigma_1)
Compute the KL between two Gaussian distributions with diagonal covariance matrices.
- Parameters
mu_0 (Array) – array-like of mean values for policy 0.
sigma_0 (Numeric) – array-like of standard deviations for policy 0.
mu_1 (Array) – array-like of mean values for policy 1.
sigma_1 (Numeric) – array-like of standard deviations for policy 1.
- Return type
Array
- Returns
the KL divergence between the distributions.
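For diagonal covariances, the KL divergence decomposes into a sum of per-dimension terms, log(sigma_1 / sigma_0) + (sigma_0**2 + (mu_0 - mu_1)**2) / (2 * sigma_1**2) - 1/2. The NumPy sketch below implements that closed form; the helper name and the reduction over the last axis are assumptions of this illustration.

```python
import numpy as np

def diagonal_gaussian_kl_sketch(mu_0, sigma_0, mu_1, sigma_1):
    """KL(N(mu_0, diag(sigma_0**2)) || N(mu_1, diag(sigma_1**2)))."""
    per_dim = (
        np.log(sigma_1 / sigma_0)
        + (np.square(sigma_0) + np.square(mu_0 - mu_1)) / (2.0 * np.square(sigma_1))
        - 0.5
    )
    return np.sum(per_dim, axis=-1)

mu_0, sigma_0 = np.zeros(2), np.ones(2)
mu_1, sigma_1 = np.ones(2), np.ones(2)
kl_shifted = diagonal_gaussian_kl_sketch(mu_0, sigma_0, mu_1, sigma_1)
kl_same = diagonal_gaussian_kl_sketch(mu_0, sigma_0, mu_0, sigma_0)
```

With unit variances the expression reduces to half the squared distance between the means, so shifting both of two dimensions by 1 gives a divergence of 1.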