## Main scripts
- `algorithm/marl_ppo.py` for training multi-agent PPO on the target MPE environment. Note: run this script as a python module with `python -m algorithm.marl_ppo` so that imports work properly.
- `envs/target_mpe_env.py` contains the main class that defines the target MPE environment. Also look at `envs/wrapper.py` for env wrappers.
- `config/mappo_config.py` is the one and only file to change config values for running experiments. Python classes are used instead of a YAML file to get autocomplete, type checking, and easier refactoring when accessing or changing the structure of the config (see the sketch after this list).
- `visualize_actor.py` for visualizing the trained actor in a local environment.
- `model/actor_critic_rnn.py` has all the flax linen networks used in the PPO.
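As a rough illustration of the config-as-Python-classes idea, here is a minimal sketch. The field names and hyperparameters below are hypothetical, not the exact contents of `config/mappo_config.py`; only `WandbConfig` and its `mode` field are named in this README.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class WandbConfig:
    mode: str = "disabled"        # set to "online" to enable wandb logging
    project: str = "my-project"   # hypothetical field
    entity: Optional[str] = None  # hypothetical field


@dataclass
class MAPPOConfig:
    num_envs: int = 16            # hypothetical hyperparameter
    lr: float = 3e-4              # hypothetical hyperparameter
    wandb: WandbConfig = field(default_factory=WandbConfig)


config = MAPPOConfig()
# Unlike a raw YAML dict, attribute access gets IDE autocomplete and type checking:
config.wandb.mode = "online"
```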
## Typical train and test flow

- Run the `train_with_gpu.ipynb` notebook in a Colab with a GPU.
- Remember to set up the config in `WandbConfig` in `config/mappo_config.py` and change `mode` to `online` to get wandb logging. The artifacts are saved under the name "PPO_RNN_Runner_State".
- Visualize the actor with `visualize_actor.py` after changing the `artifact_version` variable in the `if __name__ == "__main__"` block (see the snippet after this list).
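For reference, this is roughly how a wandb artifact saved under the name "PPO_RNN_Runner_State" is fetched by version. The entity/project placeholders are assumptions, and this is only an illustration of what `artifact_version` selects, not the actual code in `visualize_actor.py`.

```python
import wandb

artifact_version = "v0"  # pick the version of the training run you want to visualize

api = wandb.Api()
artifact = api.artifact(
    f"<your-entity>/<your-project>/PPO_RNN_Runner_State:{artifact_version}"
)
artifact_dir = artifact.download()  # local directory containing the saved runner state
```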
It is recommended to first install either `requirements_jax_cpu.txt` or `requirements_jax_cuda.txt` before `requirements.txt`, since the packages in `requirements.txt` will otherwise install a JAX version for you.
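In other words, assuming a standard pip workflow, the intended order is roughly:

```bash
# Install a JAX build matching your hardware first, then the rest of the dependencies.
pip install -r requirements_jax_cuda.txt   # or: pip install -r requirements_jax_cpu.txt
pip install -r requirements.txt
```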
## Citation

If you use JaxInforMARL in your work, please cite as follows:
```bibtex
@software{JaxInforMARL,
  title = {JaxInforMARL: Multi-Agent Target MPE RL Environments with GNNs in JAX},
  author = {Joseph Selvaraaj},
  year = {2025},
  url = {https://github.com/jselvaraaj/JaxInforMARL},
  version = {1.0.0}
}
```