[Feature] TrajCounter transform #2532

Merged · 3 commits · Nov 4, 2024
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -845,6 +845,7 @@ to be able to create this other composition:
TensorDictPrimer
TimeMaxPool
ToTensorImage
TrajCounter
UnsqueezeTransform
VC1Transform
VIPRewardTransform
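For context on the entry added to the index above, here is a minimal usage sketch of TrajCounter pieced together from the tests in this PR. The Gym backend and environment name are assumptions for illustration; the explicit transform_observation_spec call mirrors the test setup.

from torchrl.envs import GymEnv, TrajCounter, TransformedEnv

# TrajCounter writes a "traj_count" entry that is incremented at every reset,
# giving every step a trajectory index.
env = TransformedEnv(GymEnv("CartPole-v1"), TrajCounter())
env.transform.transform_observation_spec(env.base_env.observation_spec)
rollout = env.rollout(200, break_when_any_done=False)
# Each step carries the index of the trajectory it belongs to.
print(rollout["traj_count"].unique())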
210 changes: 210 additions & 0 deletions test/test_transforms.py
@@ -21,6 +21,8 @@
import tensordict.tensordict
import torch

from torchrl.collectors import MultiSyncDataCollector

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import ( # noqa
BREAKOUT_VERSIONED,
@@ -135,6 +137,7 @@
TensorDictPrimer,
TimeMaxPool,
ToTensorImage,
TrajCounter,
TransformedEnv,
UnsqueezeTransform,
VC1Transform,
@@ -1926,6 +1929,213 @@ def test_stepcounter_ignore(self):
assert env.transform.step_count_keys[0] == ("data", "step_count")


class TestTrajCounter(TransformBase):
def test_single_trans_env_check(self):
torch.manual_seed(0)
env = TransformedEnv(CountingEnv(max_steps=4), TrajCounter())
env.transform.transform_observation_spec(env.base_env.observation_spec)
check_env_specs(env)

@pytest.mark.parametrize("predefined", [True, False])
def test_parallel_trans_env_check(self, predefined):
if predefined:
t = TrajCounter()
else:
t = None

def make_env(max_steps=4, t=t):
if t is None:
t = TrajCounter()
env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone())
env.transform.transform_observation_spec(env.base_env.observation_spec)
return env

if predefined:
penv = ParallelEnv(
2,
[EnvCreator(make_env, max_steps=4), EnvCreator(make_env, max_steps=5)],
mp_start_method="spawn",
)
else:
make_env_c0 = EnvCreator(make_env)
make_env_c1 = make_env_c0.make_variant(max_steps=5)
penv = ParallelEnv(
2,
[make_env_c0, make_env_c1],
mp_start_method="spawn",
)

r = penv.rollout(100, break_when_any_done=False)
s0 = set(r[0]["traj_count"].squeeze().tolist())
s1 = set(r[1]["traj_count"].squeeze().tolist())
assert len(s1.intersection(s0)) == 0

@pytest.mark.parametrize("predefined", [True, False])
def test_serial_trans_env_check(self, predefined):
if predefined:
t = TrajCounter()
else:
t = None

def make_env(max_steps=4, t=t):
if t is None:
t = TrajCounter()
else:
t = t.clone()
env = TransformedEnv(CountingEnv(max_steps=max_steps), t)
env.transform.transform_observation_spec(env.base_env.observation_spec)
return env

if predefined:
penv = SerialEnv(
2,
[EnvCreator(make_env, max_steps=4), EnvCreator(make_env, max_steps=5)],
)
else:
make_env_c0 = EnvCreator(make_env)
make_env_c1 = make_env_c0.make_variant(max_steps=5)
penv = SerialEnv(
2,
[make_env_c0, make_env_c1],
)

r = penv.rollout(100, break_when_any_done=False)
s0 = set(r[0]["traj_count"].squeeze().tolist())
s1 = set(r[1]["traj_count"].squeeze().tolist())
assert len(s1.intersection(s0)) == 0

def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
env = TransformedEnv(
maybe_fork_ParallelEnv(
2, [lambda: CountingEnv(max_steps=4), lambda: CountingEnv(max_steps=5)]
),
TrajCounter(),
)
env.transform.transform_observation_spec(env.base_env.observation_spec)
r = env.rollout(
100,
lambda td: td.set("action", torch.ones(env.shape + (1,))),
break_when_any_done=False,
)
check_env_specs(env)
assert r["traj_count"].max() == 36

def test_trans_serial_env_check(self):
env = TransformedEnv(
SerialEnv(
2, [lambda: CountingEnv(max_steps=4), lambda: CountingEnv(max_steps=5)]
),
TrajCounter(),
)
env.transform.transform_observation_spec(env.base_env.observation_spec)
r = env.rollout(
100,
lambda td: td.set("action", torch.ones(env.shape + (1,))),
break_when_any_done=False,
)
check_env_specs(env)
assert r["traj_count"].max() == 36

def test_transform_env(self):
torch.manual_seed(0)
env = TransformedEnv(CountingEnv(max_steps=4), TrajCounter())
env.transform.transform_observation_spec(env.base_env.observation_spec)
r = env.rollout(100, lambda td: td.set("action", 1), break_when_any_done=False)
assert r["traj_count"].max() == 19

def test_nested(self):
torch.manual_seed(0)
env = TransformedEnv(
CountingEnv(max_steps=4),
Compose(
RenameTransform("done", ("nested", "done"), create_copy=True),
TrajCounter(out_key=(("nested"), (("traj_count",),))),
),
)
env.transform.transform_observation_spec(env.base_env.observation_spec)
r = env.rollout(100, lambda td: td.set("action", 1), break_when_any_done=False)
assert r["nested", "traj_count"].max() == 19

@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(self, rbclass):
t = TrajCounter()
rb = rbclass(storage=LazyTensorStorage(20))
rb.append_transform(t)
td = (
TensorDict(
{("next", "observation"): torch.randn(3), "action": torch.randn(2)}, []
)
.expand(10)
.contiguous()
)
rb.extend(td)
with pytest.raises(
RuntimeError,
match="TrajCounter can only be called within an environment step or reset",
):
td = rb.sample(10)

def test_collector_match(self):
# The counter in the collector should match the one from the transform
t = TrajCounter()

def make_env(max_steps=4):
env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone())
env.transform.transform_observation_spec(env.base_env.observation_spec)
return env

collector = MultiSyncDataCollector(
[EnvCreator(make_env, max_steps=5), EnvCreator(make_env, max_steps=4)],
total_frames=99,
frames_per_batch=8,
)
for d in collector:
# The env has one more traj because the collector calls reset during init
assert d["collector", "traj_ids"].max() == d["next", "traj_count"].max() - 1
assert d["traj_count"].max() > 0

def test_transform_compose(self):
t = TrajCounter()
t = nn.Sequential(t)
td = (
TensorDict(
{("next", "observation"): torch.randn(3), "action": torch.randn(2)}, []
)
.expand(10)
.contiguous()
)

with pytest.raises(
RuntimeError,
match="TrajCounter can only be called within an environment step or reset",
):
td = t(td)

def test_transform_inverse(self):
pytest.skip("No inverse transform for TrajCounter")

def test_transform_model(self):
t = TrajCounter()
td = (
TensorDict(
{("next", "observation"): torch.randn(3), "action": torch.randn(2)}, []
)
.expand(10)
.contiguous()
)

with pytest.raises(
RuntimeError,
match="TrajCounter can only be called within an environment step or reset",
):
td = t(td)

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("batch", [[], [4], [6, 4]])
def test_transform_no_env(self, device, batch):
pytest.skip("TrajCounter cannot be called without env")


class TestCatTensors(TransformBase):
@pytest.mark.parametrize("append", [True, False])
def test_cattensors_empty(self, append):
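The parallel and serial tests above exercise the main point of the transform: environments built from clones of a single TrajCounter draw trajectory ids from one shared counter, so ids never collide across workers. Below is a condensed sketch of that pattern as hypothetical user code; the Gym backend and environment name are assumptions, while the clone()/EnvCreator usage mirrors test_parallel_trans_env_check.

from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, TrajCounter, TransformedEnv

t = TrajCounter()

def make_env(env_name="CartPole-v1", t=t):
    # clone() keeps every copy of the transform pointing at the same
    # multiprocessing counter, so trajectory ids stay globally unique.
    env = TransformedEnv(GymEnv(env_name), t.clone())
    env.transform.transform_observation_spec(env.base_env.observation_spec)
    return env

# With the spawn start method this would typically live under an
# `if __name__ == "__main__":` guard.
penv = ParallelEnv(2, EnvCreator(make_env), mp_start_method="spawn")
rollout = penv.rollout(100, break_when_any_done=False)
ids0 = set(rollout[0]["traj_count"].flatten().tolist())
ids1 = set(rollout[1]["traj_count"].flatten().tolist())
assert ids0.isdisjoint(ids1)  # the two workers never share a trajectory id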
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
@@ -92,6 +92,7 @@
TensorDictPrimer,
TimeMaxPool,
ToTensorImage,
TrajCounter,
Transform,
TransformedEnv,
UnsqueezeTransform,
2 changes: 1 addition & 1 deletion torchrl/envs/batched_envs.py
@@ -1408,7 +1408,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
# No certainty which module multiprocessing_context is
parent_pipe, child_pipe = ctx.Pipe()
env_fun = self.create_env_fn[idx]
if not isinstance(env_fun, EnvCreator):
if not isinstance(env_fun, (EnvCreator, CloudpickleWrapper)):
env_fun = CloudpickleWrapper(env_fun)
kwargs[idx].update(
{
75 changes: 70 additions & 5 deletions torchrl/envs/env_creator.py
@@ -6,6 +6,7 @@
from __future__ import annotations

from collections import OrderedDict
from multiprocessing.sharedctypes import Synchronized
from typing import Callable, Dict, Optional, Union

import torch
@@ -33,6 +34,8 @@ class EnvCreator:
create_env_kwargs (dict, optional): the kwargs of the env creator.
share_memory (bool, optional): if False, the resulting tensordict
from the environment won't be placed in shared memory.
**kwargs: additional keyword arguments to be passed to the environment
during construction.

Examples:
>>> # We create the same environment on 2 processes using VecNorm
@@ -79,20 +82,38 @@ def __init__(
create_env_fn: Callable[..., EnvBase],
create_env_kwargs: Optional[Dict] = None,
share_memory: bool = True,
**kwargs,
) -> None:
if not isinstance(create_env_fn, EnvCreator):
if not isinstance(create_env_fn, (EnvCreator, CloudpickleWrapper)):
self.create_env_fn = CloudpickleWrapper(create_env_fn)
else:
self.create_env_fn = create_env_fn

self.create_env_kwargs = (
create_env_kwargs if isinstance(create_env_kwargs, dict) else {}
)
self.create_env_kwargs = kwargs
if isinstance(create_env_kwargs, dict):
self.create_env_kwargs.update(create_env_kwargs)
self.initialized = False
self._meta_data = None
self._share_memory = share_memory
self.init_()

def make_variant(self, **kwargs) -> EnvCreator:
"""Creates a variant of the EnvCreator, pointing to the same underlying metadata but with different keyword arguments during construction.

This can be useful with transforms that share a state, like :class:`~torchrl.envs.TrajCounter`.

Examples:
>>> from torchrl.envs import GymEnv
>>> env_creator_pendulum = EnvCreator(GymEnv, env_name="Pendulum-v1")
>>> env_creator_cartpole = env_creator_pendulum.make_variant(env_name="CartPole-v1")

"""
# Copy self
out = type(self).__new__(type(self))
out.__dict__.update(self.__dict__)
out.create_env_kwargs.update(kwargs)
return out

def share_memory(self, state_dict: OrderedDict) -> None:
for key, item in list(state_dict.items()):
if isinstance(item, (TensorDictBase,)):
@@ -101,7 +122,7 @@ def share_memory(self, state_dict: OrderedDict) -> None:
else:
torchrl_logger.info(
f"{self.env_type}: {item} is already shared"
) # , deleting key')
) # , deleting key'val)
del state_dict[key]
elif isinstance(item, OrderedDict):
self.share_memory(item)
@@ -120,12 +141,43 @@ def meta_data(self) -> EnvMetaData:
def meta_data(self, value: EnvMetaData):
self._meta_data = value

@staticmethod
def _is_mp_value(val):

return isinstance(val, (Synchronized,)) and hasattr(val, "_obj")

@classmethod
def _find_mp_values(cls, env_or_transform, values, prefix=()):
from torchrl.envs.transforms.transforms import Compose, TransformedEnv

if isinstance(env_or_transform, EnvBase) and isinstance(
env_or_transform, TransformedEnv
):
cls._find_mp_values(
env_or_transform.transform,
values=values,
prefix=prefix + ("transform",),
)
cls._find_mp_values(
env_or_transform.base_env, values=values, prefix=prefix + ("base_env",)
)
elif isinstance(env_or_transform, Compose):
for i, t in enumerate(env_or_transform.transforms):
cls._find_mp_values(t, values=values, prefix=prefix + (i,))
for k, v in env_or_transform.__dict__.items():
if cls._is_mp_value(v):
values.append((prefix + (k,), v))
return values

def init_(self) -> EnvCreator:
shadow_env = self.create_env_fn(**self.create_env_kwargs)
tensordict = shadow_env.reset()
shadow_env.rand_step(tensordict)
self.env_type = type(shadow_env)
self._transform_state_dict = shadow_env.state_dict()
# Extract any mp.Value object from the env
self._mp_values = self._find_mp_values(shadow_env, values=[])

if self._share_memory:
self.share_memory(self._transform_state_dict)
self.initialized = True
@@ -134,11 +186,24 @@ def init_(self) -> EnvCreator:
del shadow_env
return self

@classmethod
def _set_mp_value(cls, env, key, value):
if len(key) > 1:
if isinstance(key[0], int):
return cls._set_mp_value(env[key[0]], key[1:], value)
else:
return cls._set_mp_value(getattr(env, key[0]), key[1:], value)
else:
setattr(env, key[0], value)

def __call__(self, **kwargs) -> EnvBase:
if not self.initialized:
raise RuntimeError("EnvCreator must be initialized before being called.")
kwargs.update(self.create_env_kwargs) # create_env_kwargs precedes
env = self.create_env_fn(**kwargs)
if self._mp_values:
for k, v in self._mp_values:
self._set_mp_value(env, k, v)
env.load_state_dict(self._transform_state_dict, strict=False)
return env

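The _find_mp_values / _set_mp_value additions above are the plumbing that lets the shared counter survive EnvCreator's round trip through worker processes: the shadow environment is scanned for multiprocessing.Value objects, and those same objects are re-installed on every environment the creator instantiates later. A stripped-down illustration of the idea follows; TinyTransform and its traj_count attribute are hypothetical stand-ins for a transform such as TrajCounter.

import multiprocessing as mp

class TinyTransform:
    def __init__(self):
        # State kept in an mp.Value is what EnvCreator._is_mp_value
        # looks for on the shadow environment.
        self.traj_count = mp.Value("i", 0)

shadow = TinyTransform()
shared = shadow.traj_count   # what _find_mp_values would collect

fresh = TinyTransform()
fresh.traj_count = shared    # what _set_mp_value re-installs on a new env

with shared.get_lock():
    shared.value += 1
# Both transforms now observe the same counter.
assert fresh.traj_count.value == 1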