
Support for MultiDiscrete and MultiBinary action spaces in PPO #30

Merged 11 commits into araffin:master on Feb 28, 2024

Conversation

jan1854 (Collaborator) commented on Feb 20, 2024

Description

Closes #19. Adds support for MultiDiscrete and MultiBinary action spaces to PPO.

Constructs a multivariate categorical distribution from Tensorflow Probability's Independent and Categorical. Since the Categorical distribution requires every variable to have the same number of categories, I pad the logits to the largest number of categories across the action dimensions (padding with -inf ensures that these invalid actions have zero probability).
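
For illustration, here is a minimal sketch of the padding idea (my paraphrase, not the exact PR code; it assumes TFP's JAX substrate, since SBX is JAX-based, and uses a placeholder logit vector in place of the policy network output):

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax.distributions as tfd

nvec = [3, 5, 2]  # e.g. MultiDiscrete([3, 5, 2])
max_categories = max(nvec)

# Placeholder for the policy head's output: one flat vector with
# sum(nvec) logits, one block per action dimension.
flat_logits = jnp.zeros(sum(nvec))

# Split the logits per action dimension and pad each block with -inf up to
# max_categories, so the padded (invalid) actions get zero probability.
blocks = jnp.split(flat_logits, np.cumsum(nvec)[:-1])
padded = jnp.stack(
    [jnp.pad(b, (0, max_categories - b.shape[0]), constant_values=-jnp.inf) for b in blocks]
)

# One Categorical per action dimension, combined into a joint distribution
# whose log_prob sums over the dimensions.
dist = tfd.Independent(tfd.Categorical(logits=padded), reinterpreted_batch_ndims=1)
action = dist.sample(seed=jax.random.PRNGKey(0))  # shape (3,), one index per dimension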

MultiBinary is handled as a special case of MultiDiscrete with two choices per categorical variable.

Only one-dimensional action spaces are supported, so using, e.g., MultiDiscrete([[2],[3]]) or MultiBinary([2, 3]) will result in an exception (as in stable-baselines3).
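
For concreteness, the supported and rejected space shapes (examples taken from the text above; the PR only specifies that an exception is raised, not its type):

import gymnasium as gym

gym.spaces.MultiDiscrete([2, 3])      # supported: one-dimensional nvec
gym.spaces.MultiBinary(2)             # supported: scalar n, handled like MultiDiscrete([2, 2])
gym.spaces.MultiDiscrete([[2], [3]])  # rejected by PPO: multi-dimensional nvec
gym.spaces.MultiBinary([2, 3])        # rejected by PPO: multi-dimensional n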

Testing

I added some tests (tests/test_space, similar to the tests in stable-baselines3) that check that learning runs without errors and that the correct exceptions are raised when PPO is used with multi-dimensional MultiDiscrete or MultiBinary action spaces.
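
As a sketch of the shape such a test can take (DummyEnv, the timestep count, and the broad pytest.raises(Exception) are placeholders chosen for illustration, not the PR's actual test code):

import gymnasium as gym
import pytest

from sbx import PPO


class DummyEnv(gym.Env):
    # Minimal episodic environment with a configurable action space.
    def __init__(self, action_space):
        self.observation_space = gym.spaces.Box(-1.0, 1.0, shape=(2,))
        self.action_space = action_space

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        # Reward 0 and immediate truncation keep the rollout logic simple.
        return self.observation_space.sample(), 0.0, False, True, {}


@pytest.mark.parametrize(
    "action_space", [gym.spaces.MultiDiscrete([2, 3]), gym.spaces.MultiBinary(2)]
)
def test_learn_runs_without_errors(action_space):
    PPO("MlpPolicy", DummyEnv(action_space)).learn(128)


@pytest.mark.parametrize(
    "action_space", [gym.spaces.MultiDiscrete([[2], [3]]), gym.spaces.MultiBinary([2, 3])]
)
def test_multidimensional_action_space_raises(action_space):
    with pytest.raises(Exception):
        PPO("MlpPolicy", DummyEnv(action_space)).learn(128)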

To check for issues with learning performance, I compared against stable-baselines3's PPO on environments with MultiDiscrete and MultiBinary action spaces. Since the classic Gym benchmarks contain no environments with these action spaces, I used a discretized-action version of Reacher and a binary-action version of Acrobot for testing purposes (see the wrappers below).

Test script for MultiDiscrete action spaces:

from datetime import datetime
from typing import Sequence

import gymnasium as gym
import numpy as np

from sbx import PPO


class ActionDiscretizationWrapper(gym.ActionWrapper):
    """Turns a continuous Box action space into a MultiDiscrete one with the given number of bins per dimension."""

    def __init__(self, env, bins: Sequence[int]):
        super().__init__(env)
        assert isinstance(self.env.action_space, gym.spaces.Box)
        self.action_space = gym.spaces.MultiDiscrete(bins)

    def action(self, action: np.ndarray) -> np.ndarray:
        assert np.all(action < self.action_space.nvec)
        # Map bin index {0, ..., n - 1} linearly onto [low, high], so that
        # index 0 -> low and index n - 1 -> high.
        action_range = self.env.action_space.high - self.env.action_space.low
        return action_range * action / (self.action_space.nvec - 1) + self.env.action_space.low


if __name__ == "__main__":
    env = ActionDiscretizationWrapper(gym.make("Reacher-v4"), [15, 17])

    date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    agent = PPO("MlpPolicy", env, tensorboard_log=f"out/reacher_discrete_{date_time}")
    agent.learn(1000000, progress_bar=True)

Test script for MultiBinary action spaces:

from datetime import datetime

import numpy as np

import gymnasium as gym
from sbx import PPO


class BinaryAcrobotWrapper(gym.ActionWrapper):
    """Exposes Acrobot's three-way torque choice as a MultiBinary(2) action space."""

    def __init__(self, env):
        super().__init__(env)
        # One bit for applying torque -1 (original action 0) and one bit for
        # applying torque 1 (original action 2). If both bits (or neither) are
        # set, the applied torque is 0 (original action 1).
        self.action_space = gym.spaces.MultiBinary(2)

    def action(self, action: np.ndarray) -> int:
        # (0, 0) -> 1, (1, 0) -> 0, (0, 1) -> 2, (1, 1) -> 1
        return int(action[1] - action[0] + 1)


if __name__ == "__main__":
    env = BinaryAcrobotWrapper(gym.make("Acrobot-v1"))

    date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    agent = PPO("MlpPolicy", env, tensorboard_log=f"out/binary_acrobot_{date_time}")
    agent.learn(1000000, progress_bar=True)

Results: sbx's and stable-baselines3's PPO have the same learning performance.

[Plot: reacher_discrete_sbx_vs_sb3]

[Plot: acrobot_binary_sbx_vs_sb3]

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
    (The changelog seems to be in the stable-baselines3 repository, so I would need to create a separate PR for that)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
    (There is no separate documentation for sbx that I could update)
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

araffin self-requested a review on February 22, 2024
araffin (Owner) commented on Feb 23, 2024

Hello,
thanks again for the PR =)
I'll try to have a look in the coming days.

Btw, because of your good contributions, would you be interested in becoming an SBX maintainer? (So you won't have to fork the repo to fix a bug or add a feature.)

jan1854 (Collaborator, Author) commented on Feb 23, 2024

Sounds awesome, I'd be happy to become an SBX maintainer :)

araffin (Owner) commented on Feb 28, 2024

For built-in MultiDiscrete environments, I think there are the Atari games? Although we would need to use the RAM versions at first, until CNNs are supported by SBX.

araffin (Owner) left a comment

LGTM, thanks =)

@araffin araffin merged commit db6120b into araffin:master Feb 28, 2024
Development

Successfully merging this pull request may close these issues.

[Feature Request] Multi-Discrete action spaces for PPO