
Annotated implementations of equivariant (graph) neural networks in Jax: EGNN, SEGNN, NequIP.


smsharma/eqnn-jax


E(3)-Equivariant Graph Neural Networks in Jax

License: CC BY 4.0

Implementation of E(3)-equivariant graph neural networks in Jax.
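E(3) equivariance means that rotating, reflecting, or translating the input coordinates transforms the output coordinates in the same way, while invariant node features are left unchanged. As an illustration only, here is a minimal from-scratch sketch of one EGNN-style layer following the update rules of Satorras et al. (2021), written in plain jax.numpy on a fully connected graph. The helper names (init_mlp, mlp, egnn_layer) and the parameter layout are hypothetical and are not this repository's API.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: initialize an MLP as a list of (W, b) pairs."""
    params = []
    for din, dout in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((jax.random.normal(sub, (din, dout)) / jnp.sqrt(din), jnp.zeros(dout)))
    return params


def mlp(params, x):
    """Apply the MLP with SiLU activations between layers."""
    for W, b in params[:-1]:
        x = jax.nn.silu(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def egnn_layer(params, h, x):
    """One EGNN-style update. h: (N, F) invariant features, x: (N, 3) coordinates."""
    n, f = h.shape
    diff = x[:, None, :] - x[None, :, :]             # (N, N, 3) relative positions x_i - x_j
    d2 = jnp.sum(diff**2, axis=-1, keepdims=True)     # (N, N, 1) invariant squared distances
    hi = jnp.broadcast_to(h[:, None, :], (n, n, f))
    hj = jnp.broadcast_to(h[None, :, :], (n, n, f))
    mask = 1.0 - jnp.eye(n)[..., None]                # drop self-edges
    m = mlp(params["phi_e"], jnp.concatenate([hi, hj, d2], axis=-1)) * mask  # messages
    # Coordinates move along relative vectors scaled by invariant weights -> equivariant.
    w = mlp(params["phi_x"], m)                       # (N, N, 1)
    x_new = x + jnp.sum(diff * w * mask, axis=1) / (n - 1)
    # Features are updated only from invariant quantities -> invariant.
    h_new = h + mlp(params["phi_h"], jnp.concatenate([h, m.sum(axis=1)], axis=-1))
    return h_new, x_new


# Equivariance check with a random orthogonal matrix Q and a translation t.
F, M, H, N = 4, 8, 16, 5
k1, k2, k3, k4, k5, k6 = jax.random.split(jax.random.PRNGKey(0), 6)
params = {
    "phi_e": init_mlp(k1, [2 * F + 1, H, M]),
    "phi_x": init_mlp(k2, [M, H, 1]),
    "phi_h": init_mlp(k3, [F + M, H, F]),
}
h = jax.random.normal(k4, (N, F))
x = jax.random.normal(k5, (N, 3))
Q, _ = jnp.linalg.qr(jax.random.normal(k6, (3, 3)))  # rotation or reflection
t = jnp.array([1.0, -2.0, 0.5])

h1, x1 = egnn_layer(params, h, x)
h2, x2 = egnn_layer(params, h, x @ Q.T + t)           # transform the input first
print(jnp.allclose(h1, h2, atol=1e-4), jnp.allclose(x1 @ Q.T + t, x2, atol=1e-4))
```

Both checks should hold up to float32 round-off. A fully connected graph keeps the sketch short; for larger point clouds, messages are usually restricted to a radius or k-nearest-neighbor graph.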

Models

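The following equivariant models are implemented:

- EGNN (E(n) Equivariant Graph Neural Network)
- SEGNN (Steerable E(3) Equivariant Graph Neural Network)
- NequIP (Neural Equivariant Interatomic Potentials)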

Additionally, the following non-equivariant models are implemented:

Requirements and tests

To install requirements:

pip install -r requirements.txt

To run the tests (which check equivariance and periodic boundary conditions):

cd tests
pytest .
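Periodic boundary conditions arise when points live in a periodic box, as in cosmological simulation volumes. A common way to handle them is the minimum-image convention: each displacement is wrapped so that it points to the nearest periodic copy of the other point. The following is a minimal sketch assuming a cubic box of side box_size; the function names are hypothetical and the code is not taken from this repository's tests. It checks one property such a test can verify: shifting all points by the same vector and wrapping them back into the box leaves every pairwise distance unchanged.

```python
import jax
import jax.numpy as jnp


def periodic_displacement(x_i, x_j, box_size):
    """Minimum-image displacement x_i - x_j in a cubic periodic box (hypothetical helper)."""
    d = x_i - x_j
    # Subtract whole box lengths so each component lies in [-box_size/2, box_size/2].
    return d - box_size * jnp.round(d / box_size)


def pairwise_periodic_distances(x, box_size):
    """(N, N) matrix of minimum-image distances between all points."""
    diff = periodic_displacement(x[:, None, :], x[None, :, :], box_size)
    return jnp.sqrt(jnp.sum(diff**2, axis=-1))


box_size = 1.0
x = jax.random.uniform(jax.random.PRNGKey(0), (6, 3), maxval=box_size)

# Translate every point by the same vector, then wrap back into the box.
shift = jnp.array([0.3, 0.7, -0.45])
x_shifted = jnp.mod(x + shift, box_size)

d = pairwise_periodic_distances(x, box_size)
d_shifted = pairwise_periodic_distances(x_shifted, box_size)
print(jnp.allclose(d, d_shifted, atol=1e-5))
```

Without the wrapping step, two points near opposite faces of the box would appear far apart even though their periodic copies are close, and the distances would change under the shift.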

Basic usage and examples

See notebooks/examples.ipynb for example usage of GNN, SEGNN, NequIP, and EGNN.

Cosmological benchmark

The cosmological benchmarking dataset, available in TFRecord format, can be downloaded from Zenodo under the DOI 10.5281/zenodo.11479419. To download the dataset into benchmarks/galaxies/quijote_records, run:

bash benchmarks/galaxies/download_tfrecords.sh

To run the graph-level task:

python benchmarks/galaxies/train_cosmology.py

To run the node-level task:

python benchmarks/galaxies/train_velocities.py
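The two tasks differ in the shape of the prediction: a graph-level task produces one output per graph, while a node-level task produces one output per node. A minimal sketch of the two kinds of readout, assuming node embeddings h already computed by some GNN; the head functions and the mean-pooling choice are illustrative assumptions, not the benchmark scripts' code. Note that an equivariant model would predict vector targets (such as velocities) through an equivariant output layer rather than the plain linear map used here.

```python
import jax
import jax.numpy as jnp


def node_level_head(W, h):
    """Hypothetical head: one prediction per node, (N, F) -> (N, D)."""
    return h @ W


def graph_level_head(W, h):
    """Hypothetical head: mean-pool over nodes, then project, (N, F) -> (D,)."""
    return jnp.mean(h, axis=0) @ W


key_h, key_w, key_p = jax.random.split(jax.random.PRNGKey(0), 3)
h = jax.random.normal(key_h, (10, 8))   # embeddings for 10 nodes, 8 features each
W = jax.random.normal(key_w, (8, 3))
perm = jax.random.permutation(key_p, 10)

print(node_level_head(W, h).shape, graph_level_head(W, h).shape)  # (10, 3) (3,)

# Reordering nodes reorders node-level outputs but leaves the pooled output unchanged.
print(jnp.allclose(node_level_head(W, h[perm]), node_level_head(W, h)[perm]))
print(jnp.allclose(graph_level_head(W, h[perm]), graph_level_head(W, h), atol=1e-5))
```

Mean pooling is one permutation-invariant choice; sum or max pooling would serve the same role.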

Attribution

See CITATION.cff for citation information. The implementation of SEGNN was partially inspired by segnn-jax.