-
Notifications
You must be signed in to change notification settings - Fork 350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Refactor categorical dists: Masked one-hot and pass-through gradients #1488
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_single | 0.1418s | 0.1396s | 7.1611 Ops/s | 7.1508 Ops/s | |
test_sync | 0.1479s | 78.5237ms | 12.7350 Ops/s | 12.5868 Ops/s | |
test_async | 0.1954s | 72.7233ms | 13.7507 Ops/s | 13.8291 Ops/s | |
test_simple | 0.6887s | 0.6220s | 1.6076 Ops/s | 1.6173 Ops/s | |
test_transformed | 1.6858s | 1.6309s | 0.6132 Ops/s | 0.6142 Ops/s | |
test_serial | 1.7625s | 1.7123s | 0.5840 Ops/s | 0.5791 Ops/s | |
test_parallel | 1.4972s | 1.4378s | 0.6955 Ops/s | 0.6795 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1794ms | 44.9887μs | 22.2278 KOps/s | 22.1773 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 0.2239ms | 25.5373μs | 39.1585 KOps/s | 38.7855 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 96.6010μs | 31.5141μs | 31.7318 KOps/s | 31.6983 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 38.6000μs | 17.5417μs | 57.0070 KOps/s | 56.1848 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 94.9010μs | 46.7277μs | 21.4006 KOps/s | 21.4075 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 71.1010μs | 27.5730μs | 36.2674 KOps/s | 36.5384 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.3685ms | 33.7708μs | 29.6114 KOps/s | 29.3186 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 68.8010μs | 19.8516μs | 50.3738 KOps/s | 51.3053 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 79.6010μs | 48.5993μs | 20.5764 KOps/s | 20.5366 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 60.2000μs | 29.5101μs | 33.8867 KOps/s | 34.1279 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 0.1010ms | 33.3844μs | 29.9541 KOps/s | 29.4999 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 84.5010μs | 19.8849μs | 50.2894 KOps/s | 51.0246 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 0.1720ms | 49.9006μs | 20.0398 KOps/s | 19.8256 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 54.4000μs | 31.0639μs | 32.1917 KOps/s | 32.6176 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 58.5010μs | 35.7756μs | 27.9520 KOps/s | 28.4905 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 69.7000μs | 21.3991μs | 46.7310 KOps/s | 47.3521 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 0.1369ms | 48.5873μs | 20.5815 KOps/s | 20.7081 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 54.4000μs | 29.4367μs | 33.9712 KOps/s | 34.2717 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 82.9010μs | 38.1238μs | 26.2304 KOps/s | 26.7714 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 3.3822ms | 21.8752μs | 45.7139 KOps/s | 45.2709 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 94.8010μs | 50.5544μs | 19.7807 KOps/s | 19.9191 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 62.3010μs | 31.0424μs | 32.2140 KOps/s | 32.0972 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 84.2010μs | 39.6293μs | 25.2339 KOps/s | 25.8375 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 0.9742ms | 23.6283μs | 42.3221 KOps/s | 42.5106 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 0.4577ms | 52.4640μs | 19.0607 KOps/s | 19.2920 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 76.9000μs | 32.9504μs | 30.3487 KOps/s | 30.1101 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 70.5000μs | 39.7010μs | 25.1883 KOps/s | 25.5554 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 43.0000μs | 23.3009μs | 42.9168 KOps/s | 42.8993 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 0.1755ms | 53.2497μs | 18.7794 KOps/s | 18.6833 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 92.8010μs | 34.8639μs | 28.6830 KOps/s | 29.0744 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 59.9000μs | 40.2576μs | 24.8400 KOps/s | 24.9074 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 57.1000μs | 25.1747μs | 39.7225 KOps/s | 40.3614 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 15.0419ms | 13.2785ms | 75.3098 Ops/s | 72.1436 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 51.3798ms | 41.7116ms | 23.9741 Ops/s | 24.0014 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.3269ms | 0.1936ms | 5.1666 KOps/s | 4.9529 KOps/s | |
test_values[td1_return_estimate-False-False] | 13.2744ms | 12.7625ms | 78.3549 Ops/s | 74.6181 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 49.8292ms | 41.7919ms | 23.9281 Ops/s | 24.1514 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 32.4096ms | 31.7870ms | 31.4594 Ops/s | 30.3299 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 47.1532ms | 41.4111ms | 24.1481 Ops/s | 23.8551 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 11.8588ms | 11.7686ms | 84.9718 Ops/s | 83.8347 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 8.8428ms | 3.3946ms | 294.5863 Ops/s | 297.8490 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 1.3540ms | 0.4608ms | 2.1702 KOps/s | 2.1213 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 58.4290ms | 54.7609ms | 18.2612 Ops/s | 18.0248 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 8.9592ms | 2.8540ms | 350.3910 Ops/s | 353.5988 Ops/s | |
test_dqn_speed | 6.9471ms | 1.8056ms | 553.8220 Ops/s | 540.9023 Ops/s | |
test_ddpg_speed | 18.7691ms | 2.7145ms | 368.3977 Ops/s | 361.7681 Ops/s | |
test_sac_speed | 14.7899ms | 7.7538ms | 128.9692 Ops/s | 125.1345 Ops/s | |
test_redq_speed | 21.9111ms | 15.1499ms | 66.0069 Ops/s | 64.1186 Ops/s | |
test_redq_deprec_speed | 19.6961ms | 12.2972ms | 81.3194 Ops/s | 78.9255 Ops/s | |
test_td3_speed | 10.7938ms | 9.7377ms | 102.6936 Ops/s | 100.6557 Ops/s | |
test_cql_speed | 33.8275ms | 27.6488ms | 36.1679 Ops/s | 38.7518 Ops/s | |
test_a2c_speed | 15.6888ms | 5.0621ms | 197.5446 Ops/s | 195.0564 Ops/s | |
test_ppo_speed | 11.0612ms | 5.3966ms | 185.3023 Ops/s | 176.5066 Ops/s | |
test_reinforce_speed | 10.0408ms | 3.9907ms | 250.5799 Ops/s | 252.1186 Ops/s | |
test_iql_speed | 26.2949ms | 20.3941ms | 49.0338 Ops/s | 47.3095 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.9786ms | 2.5590ms | 390.7844 Ops/s | 383.6355 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 4.8742ms | 2.7265ms | 366.7760 Ops/s | 356.8004 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 3.9742ms | 2.7131ms | 368.5878 Ops/s | 362.6206 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.3823ms | 2.5622ms | 390.2892 Ops/s | 387.0785 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 4.9763ms | 2.7758ms | 360.2533 Ops/s | 359.9019 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 3.9235ms | 2.7114ms | 368.8103 Ops/s | 358.6469 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.4729ms | 2.5530ms | 391.7033 Ops/s | 387.4497 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.1291s | 3.1232ms | 320.1878 Ops/s | 362.0930 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 4.6801ms | 2.7012ms | 370.2011 Ops/s | 361.5817 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.8214ms | 2.5622ms | 390.2960 Ops/s | 388.4698 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 4.5355ms | 2.7459ms | 364.1742 Ops/s | 358.7391 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 3.9272ms | 2.7127ms | 368.6347 Ops/s | 361.7262 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.3370ms | 2.5364ms | 394.2558 Ops/s | 382.2341 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 4.4094ms | 2.7724ms | 360.6920 Ops/s | 360.8414 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.6807ms | 2.7112ms | 368.8405 Ops/s | 356.5159 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.3811ms | 2.5437ms | 393.1285 Ops/s | 386.6728 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.4248ms | 2.7807ms | 359.6161 Ops/s | 358.7148 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 4.3221ms | 2.7051ms | 369.6661 Ops/s | 357.8994 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.2597s | 28.9812ms | 34.5051 Ops/s | 34.4302 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1407s | 28.9996ms | 34.4833 Ops/s | 34.5354 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1381s | 26.2357ms | 38.1160 Ops/s | 34.7735 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1348s | 28.5255ms | 35.0564 Ops/s | 37.4434 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1342s | 26.0943ms | 38.3226 Ops/s | 34.5697 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1356s | 28.5864ms | 34.9817 Ops/s | 37.6791 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1391s | 23.9238ms | 41.7994 Ops/s | 34.1709 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1322s | 28.1218ms | 35.5596 Ops/s | 37.1617 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1295s | 25.8172ms | 38.7339 Ops/s | 34.3701 Ops/s |
MateuszGuzek
approved these changes
Sep 5, 2023
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, single comment suggesting improvement in the docstring
Co-authored-by: MateuszGuzek <[email protected]>
vmoens
added a commit
that referenced
this pull request
Sep 5, 2023
vmoens
added a commit
to hyerra/rl
that referenced
this pull request
Oct 10, 2023
… gradients (pytorch#1488) Co-authored-by: MateuszGuzek <[email protected]>
vmoens
added a commit
to hyerra/rl
that referenced
this pull request
Oct 10, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
enhancement
New feature or request
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
Implements a masked one hot distribution.
Enables 2 reparam strategies for one-hot samples: RelaxedOneHot or Pass-through
Also renamed "mask" in "action_mask" in the MaskedAction transform.
@matteobettini @MateuszGuzek