[BUG] Segmentation Fault in PrioritizedSliceSampler.sample() #2206

Closed
3 tasks done
wertyuilife2 opened this issue Jun 6, 2024 · 2 comments · Fixed by #2202
Labels
bug Something isn't working

Comments


wertyuilife2 commented Jun 6, 2024

Describe the bug

This issue was split off from the original issue #2205.

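In `PrioritizedSliceSampler.sample()`, `preceding_stop_idx` needs to be moved to the CPU before executing `self._sum_tree[preceding_stop_idx] = 0.0`. If `preceding_stop_idx` is on the GPU (for example, when the replay buffer's storage lives on a CUDA device, as in the reproduction below), that assignment crashes the program with a segmentation fault.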
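The change requested above amounts to copying the indices to host memory before the write. Below is a minimal sketch of that pattern, assuming a plain CPU tensor as a stand-in for the sampler's internal `_sum_tree`. The helper name `zero_out_priorities` is hypothetical (not torchrl API), and this is not necessarily how #2202 implements the fix.

```python
import torch


def zero_out_priorities(sum_tree: torch.Tensor, preceding_stop_idx: torch.Tensor) -> None:
    """Set the priority of every index in `preceding_stop_idx` to zero.

    `sum_tree` is a stand-in for a CPU-only priority structure. The indices may
    live on a CUDA device, so they are copied to host memory (as int64) before
    they are used for the write.
    """
    idx = preceding_stop_idx.to("cpu", torch.long)
    sum_tree[idx] = 0.0


if __name__ == "__main__":
    priorities = torch.ones(10)
    # Mimic the failing case when a GPU is present; fall back to CPU otherwise.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    stop_idx = torch.tensor([2, 5], device=device)
    zero_out_priorities(priorities, stop_idx)
    print(priorities)  # tensor([1., 1., 0., 1., 1., 0., 1., 1., 1., 1.])
```

Calling `.to("cpu", ...)` is a no-op copy-wise when the indices are already on the CPU, so the same code path works for CPU and CUDA storages.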

To Reproduce

The code below triggers a segmentation fault (it requires a CUDA device):

```python
import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from tensordict import TensorDict


def test_sampler():
    torch.manual_seed(0)

    sampler = PrioritizedSliceSampler(
        max_capacity=20,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    # Ten steps from four trajectory ids; id 3 appears at both ends of the data.
    trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        # With CUDA storage, preceding_stop_idx ends up on the GPU,
        # which is what triggers the crash inside sample().
        storage=LazyTensorStorage(20, device=torch.device("cuda")),
        batch_size=6,
    )

    rb.extend(td)
    for i in range(10):
        traj = rb.sample()["trajectory"]
        print("[loop {}]sampled trajectory: {}".format(i, traj))


test_sampler()
```
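A segmentation fault kills the interpreter without a Python traceback, which makes it hard to see which line crashed. One option, sketched below under the assumption of a POSIX system, is to run the reproduction in a child interpreter with the standard-library `faulthandler` enabled (`-X faulthandler`), which dumps the Python-level stack of the crashing thread to stderr. The helper names `run_with_faulthandler` and `crashed_with_segfault` are hypothetical.

```python
import signal
import subprocess
import sys


def run_with_faulthandler(script: str) -> subprocess.CompletedProcess:
    """Run `script` in a fresh interpreter with faulthandler enabled.

    A segfault kills the child process, not the caller, and faulthandler
    writes the Python traceback of the crashing thread to the child's stderr.
    """
    return subprocess.run(
        [sys.executable, "-X", "faulthandler", "-c", script],
        capture_output=True,
        text=True,
    )


def crashed_with_segfault(result: subprocess.CompletedProcess) -> bool:
    # On POSIX, a child killed by a signal reports the negated signal number.
    return result.returncode == -signal.SIGSEGV
```

For example, passing the contents of a file holding the reproduction above (say, a hypothetical `repro.py`) to `run_with_faulthandler` and printing `result.stderr` when `crashed_with_segfault(result)` is true shows the torchrl frame that was executing when the process died.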

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
wertyuilife2 added the bug label Jun 6, 2024
vmoens (Contributor) commented Jun 6, 2024

Should be solved by #2202

vmoens linked a pull request Jun 6, 2024 that will close this issue
vmoens (Contributor) commented Jun 7, 2024

Not solved yet - bear with me
