Support for setting the target entropy #43

Merged 4 commits on Apr 8, 2024
15 changes: 11 additions & 4 deletions sbx/crossq/crossq.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
 
 import flax
 import flax.linen as nn
@@ -66,6 +66,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[Literal["auto"], float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
@@ -103,6 +104,7 @@ def __init__(
 
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
 
         if _init_setup_model:
             self._setup_model()
@@ -155,8 +157,14 @@ def _setup_model(self) -> None:
             ),
         )
 
-        # automatically set target entropy if needed
-        self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
 
     def learn(
         self,
@@ -251,7 +259,6 @@ def update_critic(
         def mse_loss(
             params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
         ) -> Tuple[jax.Array, jax.Array]:
-
             # Joint forward pass of obs/next_obs and actions/next_state_actions to have only
             # one forward pass with shape (n_critics, 2 * batch_size, 1).
             #
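For context, the "auto" branch above keeps the standard SAC heuristic: the target entropy defaults to the negative number of action dimensions, and any other value is coerced to float (so an unexpected string raises an error). A minimal sketch of that default computation, assuming a hypothetical 3-dimensional Box action space that is not part of this diff:

import gymnasium as gym
import numpy as np

# Hypothetical continuous action space with 3 dimensions
action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,))

# Same computation as the "auto" branch: target entropy = -dim(A)
target_entropy = -np.prod(action_space.shape).astype(np.float32)
print(target_entropy)  # -3.0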
14 changes: 11 additions & 3 deletions sbx/sac/sac.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
 
 import flax
 import flax.linen as nn
@@ -67,6 +67,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[Literal["auto"], float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
@@ -105,6 +106,7 @@ def __init__(
 
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
 
         if _init_setup_model:
             self._setup_model()
@@ -157,8 +159,14 @@ def _setup_model(self) -> None:
             ),
         )
 
-        # automatically set target entropy if needed
-        self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
 
     def learn(
         self,
15 changes: 12 additions & 3 deletions sbx/tqc/tqc.py
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union

import flax
import flax.linen as nn
@@ -68,6 +68,7 @@ def __init__(
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
ent_coef: Union[str, float] = "auto",
target_entropy: Union[Literal["auto"], float] = "auto",
use_sde: bool = False,
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
@@ -106,6 +107,8 @@ def __init__(

self.policy_delay = policy_delay
self.ent_coef_init = ent_coef
self.target_entropy = target_entropy

self.policy_kwargs["top_quantiles_to_drop_per_net"] = top_quantiles_to_drop_per_net

if _init_setup_model:
@@ -159,8 +162,14 @@ def _setup_model(self) -> None:
),
)

# automatically set target entropy if needed
self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == "auto":
# automatically set target entropy if needed
self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) # type: ignore
else:
# Force conversion
# this will also throw an error for unexpected string
self.target_entropy = float(self.target_entropy)

def learn(
self,
2 changes: 1 addition & 1 deletion sbx/version.txt
@@ -1 +1 @@
-0.13.0
+0.14.0
1 change: 1 addition & 0 deletions tests/test_run.py
@@ -69,6 +69,7 @@ def test_tqc(tmp_path) -> None:
         gradient_steps=1,
         use_sde=True,
         qf_learning_rate=1e-3,
+        target_entropy=-10,
     )
     model.learn(200)
     check_save_load(model, TQC, tmp_path)
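For reference, a minimal usage sketch of the new parameter, assuming the Gymnasium Pendulum-v1 environment (which is not part of this PR):

from sbx import SAC

# Fixed target entropy instead of the automatic -dim(A) default
model = SAC("MlpPolicy", "Pendulum-v1", target_entropy=-5.0, verbose=0)
model.learn(total_timesteps=1_000)

# target_entropy="auto" (the default) keeps the previous behaviour
model_auto = SAC("MlpPolicy", "Pendulum-v1", target_entropy="auto")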