[BUG] SliceSampler breaks when at capacity #1969
Interesting! I guess the natural fix would be to tell SliceSampler that this is the same traj! I can patch that tomorrow.
I gave this some deeper thought and it's even trickier than I initially thought.
and after a while
Issue (1) can be relatively easily solved, but (2) will persist. So we should consider one of these options, or a combination of them:
@nicklashansen what's the expected behaviour in your case in the example I gave? What would be the "natural" thing to do? IMO (1) needs to be fixed for sure, and strict_length set to … ccing @Cadene, since we had a couple of conversations on the topic.
@vmoens: this is the current usage of SliceSampler in TD-MPC2: https://github.com/nicklashansen/tdmpc2/blob/57158282b46ebc5c329c5be9cfe2b0094126d1ca/tdmpc2/common/buffer.py#L17, which sometimes results in the above error when the replay buffer becomes full. I think the most natural behavior would be to simply not sample trajectories shorter than the specified length. cc colleagues @dasGringuen @aalmuzairee, who encountered this error.
@vmoens sorry to reopen this but we're still encountering errors when the replay buffer hits capacity. This is the error that I'm getting using
which is also encountered by @wertyuilife in issue nicklashansen/tdmpc2#20. My specific implementation here uses variable episode lengths of 4–251 and encounters this error only at capacity, regardless of what I set the capacity to. I believe @wertyuilife encounters this error using the official TD-MPC2 repo, which uses a fixed episode length, so it appears to be a more persistent issue.
No need to be sorry, I will investigate!
Have you set strict_length=False by the way? |
I can reproduce this, but it's expected when strict_length=True:

```python
import torch
import tqdm
from tensordict import TensorDict
from torchrl.data import ReplayBuffer, SliceSampler, LazyTensorStorage

rb = ReplayBuffer(
    storage=LazyTensorStorage(1000),
    # Change strict_length=True to get the error
    sampler=SliceSampler(slice_len=4, traj_key="traj", strict_length=False),
    batch_size=256,
)
for i in tqdm.tqdm(range(10_000)):
    n = torch.randint(2, 50, ()).item()
    td = TensorDict(
        {"a": torch.randn(n, 3), "traj": torch.full((n,), i, dtype=torch.float32)},
        [n],
    )
    rb.extend(td)
    if i > 10:
        rb.sample()
```
This is the current usage of the SliceSampler: https://github.com/nicklashansen/tdmpc2/blob/57158282b46ebc5c329c5be9cfe2b0094126d1ca/tdmpc2/common/buffer.py#L17
which returns the error
when
presumably because the implementation returns a different number of elements than expected? @vmoens, what is the recommended way of using this sampler? Is there any way to simply not sample sequences that are invalid? As it stands, it seems like my implementation will break one way or another as soon as the buffer hits capacity. I'm happy to change my own implementation to accommodate the SliceSampler, but I don't quite see an easy solution at the moment. Also cc @wertyuilife @dasGringuen @rokas-bendikas @jyothirsv, who are affected by this.
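On the "just not sample invalid sequences" part, one possible workaround sketch (my own, hedged, not an official TorchRL recipe; drop_short_slices is a hypothetical helper): with strict_length=False the sampler may return slices shorter than slice_len, and those can be masked out after sampling by grouping the returned steps on the traj key.

```python
def drop_short_slices(traj_ids, slice_len):
    """Given the flat per-step trajectory ids of a sampled batch, return a
    boolean mask keeping only steps that belong to runs of exactly slice_len
    consecutive steps from the same trajectory (i.e. full-length slices)."""
    mask = [False] * len(traj_ids)
    i = 0
    while i < len(traj_ids):
        # Find the end of the current run of identical trajectory ids.
        j = i
        while j < len(traj_ids) and traj_ids[j] == traj_ids[i]:
            j += 1
        # Keep the run only if it has the full requested length.
        if j - i == slice_len:
            for k in range(i, j):
                mask[k] = True
        i = j

    return mask

# A batch with one truncated slice (trajectory 3 only has 2 of 4 steps):
mask = drop_short_slices([7, 7, 7, 7, 3, 3, 5, 5, 5, 5], slice_len=4)
print(mask)  # trajectory 3's two steps are masked out
```

The same masking could be applied to the sampled tensordict before computing losses, at the cost of a variable effective batch size.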
I think I see more clearly now.
Sounds great. Thank you!
@vmoens I have checked that it no longer throws an exception for me and pushed an update to tdmpc2 here: nicklashansen/tdmpc2@5f6fade Thanks again for your help!
Hi @vmoens, I ran into this issue with the SliceSampler, which does not appear to be intentional behavior!
Describe the bug
The current implementation of SliceSampler will raise the exception
when an episode of length greater than slice_len is added to the replay buffer while it is close to capacity. It appears that episodes are "wrapped around" to the beginning of the replay buffer, but the sampler does not account for this and thus raises an exception.
This issue affects all use cases in which the replay buffer capacity is not a multiple of the episode length (or in which episodes have varying lengths).
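To make the wrap-around concrete, here is a toy model of a fixed-capacity ring storage (my own sketch of the general technique, not TorchRL's actual LazyTensorStorage internals): once the write cursor passes the capacity, an episode's steps land partly at the tail and partly back at the head of the buffer.

```python
# Toy ring buffer: capacity 10, episodes of length 4, so the third episode
# cannot fit contiguously and wraps around to the start of the storage.
capacity = 10
storage = [None] * capacity
cursor = 0

def extend(episode):
    """Write an episode step by step; physical indices wrap modulo capacity."""
    global cursor
    indices = []
    for step in episode:
        idx = cursor % capacity
        storage[idx] = step
        indices.append(idx)
        cursor += 1
    return indices

extend([("ep0", t) for t in range(4)])            # fills slots 0-3
extend([("ep1", t) for t in range(4)])            # fills slots 4-7
wrapped = extend([("ep2", t) for t in range(4)])  # fills slots 8, 9, 0, 1
print(wrapped)  # [8, 9, 0, 1]: ep2 is split across the buffer boundary
```

Any sampler that assumes an episode occupies one contiguous physical range will then see ep2 as two fragments, which matches the failure mode described above.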
To Reproduce
The error can be reproduced by running the following example code:
Expected behavior
I would expect the sampler to account for the possibility that an episode may be split between end indices and start indices when the replay buffer is at capacity.
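One way such a check could look (a hedged sketch of the general idea, not TorchRL's implementation; valid_slice_starts is a hypothetical helper): scan the per-index trajectory ids and only admit slice starts whose window stays inside a single trajectory and does not straddle the write cursor, i.e. the boundary between the newest and oldest data.

```python
def valid_slice_starts(traj_ids, slice_len, cursor):
    """traj_ids[i] is the trajectory id stored at physical index i.

    Returns physical start indices i such that steps i..i+slice_len-1
    (taken contiguously, without wrapping) belong to one trajectory and
    do not cross the write cursor, where the newest data meets the oldest."""
    n = len(traj_ids)
    starts = []
    for i in range(n - slice_len + 1):
        window = traj_ids[i : i + slice_len]
        same_traj = all(t == window[0] for t in window)
        crosses_cursor = i < cursor <= i + slice_len - 1
        if same_traj and not crosses_cursor:
            starts.append(i)
    return starts

# Buffer of 10 with the cursor at index 2: trajectory "B" is the newest
# episode and wraps from slots 6-9 into slots 0-1; "A" fills slots 2-5.
traj = ["B", "B", "A", "A", "A", "A", "B", "B", "B", "B"]
print(valid_slice_starts(traj, slice_len=3, cursor=2))  # [2, 3, 6, 7]
```

Starts 0 and 1 are rejected because their windows would mix B's tail with A's steps and cross the cursor; this deliberately forgoes the wrapped part of B rather than raising.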