[BugFix] torch 2.0 compatibility fix #2475

Merged · 1 commit · Oct 9, 2024
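The change is a small compatibility shim: `torch.compiler.is_dynamo_compiling` is not available in torch 2.0, so each module that needs the flag now imports it with a fallback to `torch._dynamo.is_compiling`. Below is a minimal sketch of the pattern and of how the flag behaves under `torch.compile`; the `f` function is a hypothetical illustration, not part of the PR.

```python
import torch

# Compatibility shim: torch 2.0 does not expose torch.compiler.is_dynamo_compiling,
# so fall back to the older torch._dynamo.is_compiling helper.
try:
    from torch.compiler import is_dynamo_compiling
except ImportError:
    from torch._dynamo import is_compiling as is_dynamo_compiling


def f(x):
    # While dynamo traces this function the flag is True, so code that dynamo
    # should not trace (e.g. weakref caching) can be skipped in the compiled path.
    if is_dynamo_compiling():
        return x + 1
    return x - 1


print(f(torch.zeros(())))                 # eager: tensor(-1.)
print(torch.compile(f)(torch.zeros(())))  # compiled: tensor(1.)
```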
9 changes: 7 additions & 2 deletions torchrl/__init__.py
@@ -27,6 +27,11 @@
 except ImportError:
     __version__ = None
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 _init_extension()
 
 try:
@@ -69,7 +74,7 @@ def _inv(self):
         inv = self._inv()
     if inv is None:
         inv = _InverseTransform(self)
-        if not torch.compiler.is_dynamo_compiling():
+        if not is_dynamo_compiling():
             self._inv = weakref.ref(inv)
     return inv
 
@@ -84,7 +89,7 @@ def _inv(self):
         inv = self._inv()
     if inv is None:
         inv = ComposeTransform([p.inv for p in reversed(self.parts)])
-        if not torch.compiler.is_dynamo_compiling():
+        if not is_dynamo_compiling():
             self._inv = weakref.ref(inv)
             inv._inv = weakref.ref(self)
         else:
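All the guarded lines above follow the same shape: the inverse transform is cached on the instance through a weakref in eager mode, but the cache write is skipped while dynamo is tracing, presumably because the weakref/attribute mutation is not something the compiled graph should capture. A minimal sketch of that pattern follows; the `_LazyInverse` and `_Inverse` classes are illustrative stand-ins, not torchrl code.

```python
import weakref

try:
    from torch.compiler import is_dynamo_compiling
except ImportError:  # torch 2.0 fallback
    from torch._dynamo import is_compiling as is_dynamo_compiling


class _Inverse:
    """Placeholder for the real _InverseTransform."""


class _LazyInverse:
    """Illustrative stand-in for the patched Transform.inv property."""

    def __init__(self):
        self._inv = None  # weakref to the cached inverse, or None

    @property
    def inv(self):
        inv = self._inv() if self._inv is not None else None
        if inv is None:
            inv = _Inverse()  # placeholder for _InverseTransform(self)
            if not is_dynamo_compiling():
                # Only cache in eager mode; skip the weakref write while tracing.
                self._inv = weakref.ref(inv)
        return inv
```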
21 changes: 11 additions & 10 deletions torchrl/modules/distributions/continuous.py
@@ -36,6 +36,11 @@
 # speeds up distribution construction
 D.Distribution.set_default_validate_args(False)
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 
 class IndependentNormal(D.Independent):
     """Implements a Normal distribution with location scaling.
@@ -112,7 +117,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _InverseTransform(self)
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
         return inv
 
@@ -320,7 +325,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
                 inv._inv = weakref.ref(self)
         return inv
@@ -334,7 +339,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _InverseTransform(self)
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
         return inv
 
@@ -432,15 +437,13 @@ def __init__(
         self.high = high
 
         if safe_tanh:
-            if torch.compiler.is_dynamo_compiling():
+            if is_dynamo_compiling():
                 _err_compile_safetanh()
             t = SafeTanhTransform()
         else:
             t = D.TanhTransform()
         # t = D.TanhTransform()
-        if torch.compiler.is_dynamo_compiling() or (
-            self.non_trivial_max or self.non_trivial_min
-        ):
+        if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
             t = _PatchedComposeTransform(
                 [
                     t,
@@ -467,9 +470,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         if self.tanh_loc:
             loc = (loc / self.upscale).tanh() * self.upscale
         # loc must be rescaled if tanh_loc
-        if torch.compiler.is_dynamo_compiling() or (
-            self.non_trivial_max or self.non_trivial_min
-        ):
+        if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
             loc = loc + (self.high - self.low) / 2 + self.low
         self.loc = loc
         self.scale = scale
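Apart from renaming the flag, the `update` hunk keeps the existing rescaling line, which shifts `loc` toward the centre of the `[low, high]` interval when non-trivial bounds are set. A quick numeric check of that arithmetic, as a standalone sketch that is not part of the PR:

```python
import torch


def rescale_loc(loc, low, high):
    # Same arithmetic as the line kept in TanhNormal.update:
    # loc + (high - low) / 2 + low
    return loc + (high - low) / 2 + low


# With loc = 0 the result is the midpoint of [low, high]:
print(rescale_loc(torch.tensor(0.0), torch.tensor(-1.0), torch.tensor(3.0)))  # tensor(1.)
```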
7 changes: 6 additions & 1 deletion torchrl/objectives/utils.py
@@ -26,6 +26,11 @@
         raise err_ft from err
 from torchrl.envs.utils import step_mdp
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 _GAMMA_LMBDA_DEPREC_ERROR = (
     "Passing gamma / lambda parameters through the loss constructor "
     "is a deprecated feature. To customize your value function, "
@@ -460,7 +465,7 @@ def _cache_values(func):
 
     @functools.wraps(func)
    def new_func(self, netname=None):
-        if torch.compiler.is_dynamo_compiling():
+        if is_dynamo_compiling():
             if netname is not None:
                 return func(self, netname)
             else:
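The decorator touched here caches values on the loss instance in eager mode and bypasses the cache while compiling, so dynamo only ever sees the plain function call. A hedged sketch of that shape is below; the cache-attribute naming and the eager branch are assumptions for illustration, not the actual torchrl implementation.

```python
import functools

try:
    from torch.compiler import is_dynamo_compiling
except ImportError:  # torch 2.0 fallback
    from torch._dynamo import is_compiling as is_dynamo_compiling


def cache_values_sketch(func):
    """Skip instance-level caching while dynamo traces, cache otherwise."""
    attr = "_cached_" + func.__name__  # hypothetical cache attribute name

    @functools.wraps(func)
    def new_func(self, netname=None):
        if is_dynamo_compiling():
            # Compiled path: recompute instead of reading/writing the cache.
            return func(self, netname) if netname is not None else func(self)
        out = getattr(self, attr, None)
        if out is None:
            out = func(self, netname) if netname is not None else func(self)
            setattr(self, attr, out)
        return out

    return new_func
```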