Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] D4rl direct download #1430

Merged
merged 14 commits into from
Oct 4, 2023
Prev Previous commit
Next Next commit
fix
vmoens committed Oct 4, 2023
commit 16cdadb936f5a4dea75460eb1758ce5f0dace68c
12 changes: 8 additions & 4 deletions test/test_libs.py
Original file line number Diff line number Diff line change
@@ -1775,7 +1775,7 @@ class TestD4RL:
def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs):

with pytest.warns(
UserWarning, match="Using terminate_on_end=True with from_env=False"
UserWarning, match="Using use_truncated_as_done=True"
) if use_truncated_as_done else nullcontext():
data_true = D4RLExperienceReplay(
task,
@@ -1836,7 +1836,7 @@ def test_direct_download(self, task):
data_d4rl = D4RLExperienceReplay(
task,
split_trajs=False,
from_env=False,
from_env=True,
batch_size=2,
use_truncated_as_done=True,
direct_download=False,
@@ -1846,8 +1846,12 @@ def test_direct_download(self, task):
keys = keys.intersection(data_d4rl._storage._storage.keys(True, True))
assert len(keys)
assert_allclose_td(
data_direct._storage._storage.select(*keys),
data_d4rl._storage._storage.select(*keys),
data_direct._storage._storage.select(*keys).apply(
lambda t: t.as_tensor().float()
),
data_d4rl._storage._storage.select(*keys).apply(
lambda t: t.as_tensor().float()
),
)

@pytest.mark.parametrize(
2 changes: 1 addition & 1 deletion torchrl/data/datasets/d4rl.py
Original file line number Diff line number Diff line change
@@ -157,7 +157,6 @@ def __init__(
if terminate_on_end is None:
# we use the default of d4rl
terminate_on_end = False
env_kwargs.update({"terminate_on_end": terminate_on_end})
self._import_d4rl()

if not self._has_d4rl:
@@ -173,6 +172,7 @@ def __init__(
"as the timeouts (truncation) "
"can be absent from the static dataset."
)
env_kwargs.update({"terminate_on_end": terminate_on_end})
dataset = self._get_dataset_direct(name, env_kwargs)
else:
if terminate_on_end is False: