
[Feature] D4rl direct download #1430

Merged
merged 14 commits on Oct 4, 2023
test direct download
vmoens committed Oct 4, 2023
commit 920163e2cd29ec0589456a5b44be00b6efff590d
27 changes: 27 additions & 0 deletions test/test_libs.py
@@ -1823,6 +1823,33 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs):
        ]
        assert "truncated" not in leaf_names

@pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
def test_direct_download(self, task):
data_direct = D4RLExperienceReplay(
task,
split_trajs=False,
from_env=False,
batch_size=2,
use_truncated_as_done=True,
direct_download=True,
)
data_d4rl = D4RLExperienceReplay(
task,
split_trajs=False,
from_env=False,
batch_size=2,
use_truncated_as_done=True,
direct_download=False,
terminate_on_end=True, # keep the last time step
)
keys = set(data_direct._storage._storage.keys(True, True))
keys = keys.intersection(data_d4rl._storage._storage.keys(True, True))
assert len(keys)
assert_allclose_td(
data_direct._storage._storage.select(*keys),
data_d4rl._storage._storage.select(*keys),
)

    @pytest.mark.parametrize(
        "task",
        [
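As a quick illustration of what the new flag enables outside the test suite, here is a minimal usage sketch (not part of the PR): it builds the buffer with ``direct_download=True`` and samples one batch. The task name mirrors the parametrization above, and the sampled keys are assumed to follow TorchRL's usual ``"observation"`` / ``("next", "reward")`` layout.

from torchrl.data.datasets.d4rl import D4RLExperienceReplay

# Sketch only: download the dataset without the d4rl package and draw one batch.
data = D4RLExperienceReplay(
    "walker2d-medium-replay-v2",
    batch_size=32,
    split_trajs=False,
    from_env=False,
    direct_download=True,  # skip the d4rl dependency entirely
)
sample = data.sample()  # TensorDict holding 32 transitions
print(sample["observation"].shape, sample["next", "reward"].shape)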
53 changes: 42 additions & 11 deletions torchrl/data/datasets/d4rl.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import os
import urllib
import warnings
@@ -80,13 +82,19 @@ class D4RLExperienceReplay(TensorDictReplayBuffer):
differ. In particular, the ``"truncated"`` key (used to determine the
end of an episode) may be absent when ``from_env=False`` but present
otherwise, leading to a different slicing when ``split_trajs`` is enabled.
direct_download (bool): if ``True`` (default), the data will be downloaded without
requiring D4RL. This is not compatible with ``from_env=True``.
direct_download (bool): if ``True``, the data will be downloaded without
requiring D4RL. If ``None``, ``d4rl`` will be used to download the dataset
if it is installed in the environment; otherwise the download will fall
back on ``direct_download=True``.
This is not compatible with ``from_env=True``.
Defaults to ``None``.
use_truncated_as_done (bool, optional): if ``True``, ``done = terminated | truncated``.
Otherwise, only the ``terminated`` key is used. Defaults to ``True``.
terminate_on_end (bool, optional): Set ``done=True`` on the last timestep
in a trajectory. Defaults to ``False``, in which case the last
timestep in each trajectory is discarded.
**env_kwargs (key-value pairs): additional kwargs for
:func:`d4rl.qlearning_dataset`. Supports ``terminate_on_end``
(``False`` by default) or other kwargs if defined by D4RL library.
:func:`d4rl.qlearning_dataset`.


Examples:
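To make the fallback described above concrete, the snippet below (illustrative only, not part of the patch) condenses how the effective download path is resolved when ``from_env=False``; ``has_d4rl`` stands in for the ``_import_d4rl()`` check that appears further down in this diff.

def resolve_download_path(direct_download, has_d4rl):
    # Sketch of the docstring's fallback: prefer d4rl when it is installed,
    # otherwise download the dataset directly.
    if direct_download is None:
        direct_download = not has_d4rl
    return "direct HTTP download" if direct_download else "d4rl.qlearning_dataset"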
@@ -115,16 +123,17 @@ def __init__(
self,
name,
batch_size: int,
sampler: Optional[Sampler] = None,
writer: Optional[Writer] = None,
collate_fn: Optional[Callable] = None,
sampler: Sampler | None = None,
writer: Writer | None = None,
collate_fn: Callable | None = None,
pin_memory: bool = False,
prefetch: Optional[int] = None,
transform: Optional["Transform"] = None, # noqa-F821
prefetch: int | None = None,
transform: "torchrl.envs.Transform" | None = None, # noqa-F821
split_trajs: bool = False,
from_env: bool = None,
use_truncated_as_done: bool = True,
direct_download: bool = True,
direct_download: bool = None,
terminate_on_end: bool = None,
**env_kwargs,
):
if from_env is None:
@@ -139,7 +148,16 @@ def __init__(
from_env = True
self.from_env = from_env
self.use_truncated_as_done = use_truncated_as_done

if not from_env and direct_download is None:
self._import_d4rl()
direct_download = not self._has_d4rl

if not direct_download:
if terminate_on_end is None:
# we use the default of d4rl
terminate_on_end = False
env_kwargs.update({"terminate_on_end": terminate_on_end})
self._import_d4rl()

if not self._has_d4rl:
@@ -148,8 +166,19 @@ def __init__(
if from_env:
dataset = self._get_dataset_from_env(name, env_kwargs)
else:
if self.use_truncated_as_done:
warnings.warn(
"Using use_truncated_as_done=True + terminate_on_end=True "
"with from_env=False may not have the intended effect "
"as the timeouts (truncation) "
"can be absent from the static dataset."
)
dataset = self._get_dataset_direct(name, env_kwargs)
else:
if terminate_on_end is False:
raise ValueError(
"Using terminate_on_end=False is not compatible with direct_download=True."
)
dataset = self._get_dataset_direct_download(name, env_kwargs)
# Fill unknown next states with 0
dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0
@@ -174,7 +203,9 @@ def __init__(
def _get_dataset_direct_download(self, name, env_kwargs):
"""Directly download and use a D4RL dataset."""
if env_kwargs:
raise RuntimeError("Cannot pass env_kwargs when `direct_download=True`.")
raise RuntimeError(
f"Cannot pass env_kwargs when `direct_download=True`. Got env_kwargs keys: {env_kwargs.keys()}"
)
url = D4RL_DATASETS.get(name, None)
if url is None:
raise KeyError(f"Env {name} not found.")
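The rest of ``_get_dataset_direct_download`` is not shown in this excerpt. As a rough sketch of the direct path it introduces, assuming only that ``D4RL_DATASETS`` maps task names to downloadable file URLs, the lookup-and-fetch step could look roughly like this (hypothetical helper and cache location, standard-library ``urllib`` only):

import os
import urllib.request

def _fetch_dataset_file(name, datasets, cache_dir="~/.cache/torchrl/d4rl"):
    # Resolve the task name to a URL, mirroring the KeyError above.
    url = datasets.get(name, None)
    if url is None:
        raise KeyError(f"Env {name} not found.")
    # Download the file once and reuse it on subsequent calls.
    cache_dir = os.path.expanduser(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    dest = os.path.join(cache_dir, os.path.basename(url))
    if not os.path.exists(dest):
        urllib.request.urlretrieve(url, dest)
    return dest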