[BugFix] PPOs with composite distribution #2791
Conversation
CI status (hud.pytorch.org/pr/pytorch/rl/2791): as of commit eadb9e1 with merge base 27a8ecc, 1 new failure and 21 pending jobs.
I checked that all PPO tests are still passing; two comments:
- We should probably have a test to catch this (a minimal sketch of such a check is below),
- I'm planning to do a docstring pass on the entire PPO stack so things can be much clearer (some operations are a bit obscure at first read).
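For the first point, here is a minimal, self-contained sketch of the kind of property such a test could assert; it is not the actual torchrl test suite. `sum_td_features` below is a hypothetical stand-in for the private `_sum_td_features` helper, and the two-head composite log-weight and clip value are made up for illustration:

```python
import torch
from tensordict import TensorDict


def sum_td_features(td: TensorDict) -> torch.Tensor:
    # Hypothetical stand-in for torchrl's private _sum_td_features:
    # add every leaf tensor (one log-prob per action head) into one tensor.
    return sum(td.values(include_nested=True, leaves_only=True))


def test_clamp_happens_after_feature_sum():
    torch.manual_seed(0)
    # Composite distribution with two action heads, batch of 8.
    log_weight = TensorDict(
        {"head_a": 3 * torch.randn(8), "head_b": 3 * torch.randn(8)},
        batch_size=[8],
    )
    eps = 0.2
    # Buggy order: clip each head's log-weight, then sum over heads.
    buggy = sum_td_features(log_weight.apply(lambda x: x.clamp(-eps, eps)))
    # Fixed order: sum over heads first, then clip the joint log-weight.
    fixed = sum_td_features(log_weight).clamp(-eps, eps)
    # The two orderings must not coincide in general.
    assert not torch.allclose(buggy, fixed)
```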
if is_tensor_collection(log_weight):
    log_weight = _sum_td_features(log_weight)
log_weight = log_weight.view(adv_shape).unsqueeze(-1)
That is the main change for this method, which is also now consistent with type hints.
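For readers skimming the diff, a rough sketch of what these lines do with a composite log-weight; the shapes and key names here are hypothetical, and the plain sum stands in for `_sum_td_features`:

```python
import torch
from tensordict import TensorDict

# Hypothetical setup: batch of 32 transitions, two action heads.
adv_shape = (32,)
log_weight = TensorDict(
    {"head_a": torch.randn(32), "head_b": torch.randn(32)},
    batch_size=[32],
)

# Collapse the composite log-weight into one tensor per transition...
lw = sum(log_weight.values(include_nested=True, leaves_only=True))
# ...then match the advantage's batch shape and add a trailing singleton
# dimension so it broadcasts against an advantage of shape (*adv_shape, 1).
lw = lw.view(adv_shape).unsqueeze(-1)
assert lw.shape == (32, 1)
```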
@@ -987,8 +982,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         # to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
         # dispersion.
         lw = log_weight.squeeze()
-        if not isinstance(lw, torch.Tensor):
-            lw = _sum_td_features(lw)
         ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
         batch = log_weight.shape[0]
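For context, the `ess` line computes the effective sample size of the importance weights in log-space, ESS = (sum_i w_i)^2 / sum_i w_i^2. A standalone numerical check of that identity, using made-up log-weights rather than torchrl code:

```python
import torch

torch.manual_seed(0)
lw = torch.randn(256)  # per-sample log importance weights (illustrative)
w = lw.exp()

# Direct definition: ESS = (sum_i w_i)^2 / sum_i w_i^2
ess_direct = w.sum().pow(2) / w.pow(2).sum()
# Log-space form used in the diff: exp(2*logsumexp(lw) - logsumexp(2*lw))
ess_logspace = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()

assert torch.allclose(ess_direct, ess_logspace, rtol=1e-4)
```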
The main error is two lines below: clamp was applied to the TensorDict log_weight before it was summed over the feature dimension.
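To make the ordering concrete with made-up numbers (two per-head log-weights and a hypothetical clip value of 0.2): clamping each entry of the TensorDict and then summing is not the same as summing the heads first and clamping the result.

```python
import torch

eps = 0.2
a, b = torch.tensor(0.5), torch.tensor(-0.4)   # per-head log-weights

clamp_then_sum = a.clamp(-eps, eps) + b.clamp(-eps, eps)   # 0.2 + (-0.2) = 0.0
sum_then_clamp = (a + b).clamp(-eps, eps)                  # clamp(0.1)   = 0.1
```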
LGTM, thanks!
Co-authored-by: Louis Faury <[email protected]> (cherry picked from commit edfa25d)
Description
I believe there is a bug in the PPO implementations when both prev_log_prob and log_prob are TensorDicts.
Motivation and Context
In the setting where both prev_log_prob and log_prob are TensorDicts, we were clamping prev_log_prob - log_prob directly, instead of their sum over features.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!