A relatively carefully written library in JAX to support my own research (and hopefully help others too).
It is not intended to be a framework that satisfies every model type and every user (unless it goes viral, which is unlikely), but anybody is free to contribute (so far it is just myself).
It contains two separate parts:
- Model architectures (`jaxml.models`)
- Inference engine (`jaxml.inference_engine`)
The model architecture definitions in turn use the following:
- Neural network components (`jaxml.nn`)
- Model configs (`jaxml.config`)
Currently supported models:
- Llama
Inference engine features:
- Tensor parallelism and data parallelism (using JAX sharding semantics)
- AOT compilation of the prefill and decode functions, with the compiled executables cached for reuse
- Optional JAX flash attention (`jax.experimental.pallas.ops.flash_attention`)