
[Feature] Single call to value network in advantages #1256

Merged: vmoens merged 3 commits into main from single_call_adv on Jun 13, 2023

Conversation

vmoens (Contributor) commented on Jun 12, 2023

This PR allows advantage modules to call the value model once and only once.
If adv.shifted is set to True and the params at t and t+1 match, the value net is called only once.
In all other cases, vmap is used to batch the calls to the value net.

cc @tcbegley @apbard
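
[Editor's note] For illustration, here is a minimal, hedged sketch of the "shifted" single-call idea in plain PyTorch. This is not torchrl's actual code; value_net and the shapes are toy placeholders. The point is that when the same parameters serve t and t+1, observations at t and t+1 overlap along the trajectory, so one forward pass over the whole trajectory yields both value and next_value:

import torch
import torch.nn as nn

value_net = nn.Linear(4, 1)          # toy stand-in for the value network

traj = torch.randn(11, 4)            # observations o_0 .. o_10 of one trajectory
obs, next_obs = traj[:-1], traj[1:]  # o_t and o_{t+1} share o_1 .. o_9

# Two-call version: 20 rows pushed through the net.
value = value_net(obs)
next_value = value_net(next_obs)

# Single "shifted" call: 11 rows, valid because the same
# parameters are used at t and t+1.
all_values = value_net(traj)
assert torch.allclose(value, all_values[:-1])
assert torch.allclose(next_value, all_values[1:])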

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jun 12, 2023
@vmoens added the enhancement (new feature or request) label on Jun 12, 2023
@vmoens merged commit fccad08 into main on Jun 13, 2023
@vmoens deleted the single_call_adv branch on June 13, 2023 at 09:40
Comment on lines +1118 to +1145
# kwargs = {}
# if self.is_stateless and params is None:
# raise RuntimeError(
# "Expected params to be passed to advantage module but got none."
# )
# if params is not None:
# kwargs["params"] = params
#
# if self.value_network is not None:
# with hold_out_net(self.value_network):
# # we may still need to pass gradient, but we don't want to assign grads to
# # value net params
# self.value_network(tensordict, **kwargs)
#
# value = tensordict.get(self.tensor_keys.value)
#
# step_td = step_mdp(tensordict)
# if target_params is not None:
# # we assume that target parameters are not differentiable
# kwargs["params"] = target_params
# elif "params" in kwargs:
# kwargs["params"] = kwargs["params"].detach()
# if self.value_network is not None:
# with hold_out_net(self.value_network):
# # we may still need to pass gradient, but we don't want to assign grads to
# # value net params
# self.value_network(step_td, **kwargs)
# next_value = step_td.get(self.tensor_keys.value)
Contributor commented:
Is the commented-out code left on purpose?
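
[Editor's note] The commented-out block above is the old two-call path: one forward pass on tensordict and another on step_td. For the case the PR description mentions where the params at t and t+1 differ (e.g. target_params), here is a hedged sketch of how two calls could be batched into one with vmap, using torch.func ensembling. This is an assumption for illustration, not necessarily the PR's implementation, and value_net is a toy placeholder:

import copy
import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

value_net = nn.Linear(4, 1)            # toy value network
target_net = copy.deepcopy(value_net)  # stand-in for target (t+1) parameters

obs = torch.randn(10, 4)               # observations at t
next_obs = torch.randn(10, 4)          # observations at t+1

# Stack both parameter sets along a new leading dimension.
params, buffers = stack_module_state([value_net, target_net])

def run_value_net(p, b, x):
    # Run value_net's forward with the given (stacked) params and buffers.
    return functional_call(value_net, (p, b), (x,))

# One vmapped call evaluates both (params, observations) pairs at once.
batched_obs = torch.stack([obs, next_obs])   # shape [2, 10, 4]
out = vmap(run_value_net)(params, buffers, batched_obs)
value, next_value = out[0], out[1]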

@vmoens restored the single_call_adv branch on June 13, 2023 at 10:27
vmoens added a commit that referenced this pull request on Jun 13, 2023