Welcome to vbjax’s documentation!

Introduction

vbjax is a JAX-based package for working with virtual brain style models.
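
To give a flavor of the workflow, here is a minimal sketch in plain JAX rather than the vbjax API itself: a generic two-variable neural mass model coupled over a toy connectome and integrated with a jit-compiled Euler loop. The model, its parameters & helper names here are illustrative only; see the tutorial for vbjax's own building blocks.

# Minimal sketch w/ plain JAX (not the vbjax API): a generic two-variable
# neural mass model coupled over a random connectome, integrated with
# Euler steps under jax.lax.scan.
import jax
import jax.numpy as jnp

def dfun(x, W, a=1.0, tau=3.0, k=0.1):
    # x has shape (2, n_nodes): one fast & one slow variable per node
    u, v = x
    coupling = k * (W @ u)                  # network input to each node
    du = tau * (u - u**3 / 3 + v) + coupling
    dv = (a - u) / tau
    return jnp.stack([du, dv])

def make_loop(dt, n_steps, W):
    def step(x, _):
        x = x + dt * dfun(x, W)             # Euler step
        return x, x
    def loop(x0):
        _, xs = jax.lax.scan(step, x0, None, length=n_steps)
        return xs                           # shape (n_steps, 2, n_nodes)
    return jax.jit(loop)

key = jax.random.PRNGKey(42)
W = jax.random.normal(key, (32, 32)) / 32   # toy connectome
loop = make_loop(dt=0.01, n_steps=1000, W=W)
xs = loop(jnp.zeros((2, 32)) - 1.0)
print(xs.shape)                             # (1000, 2, 32)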

Installation

vbjax installs from PyPI with pip install "vbjax", but you can also install from source:

git clone https://github.com/ins-amu/vbjax
cd vbjax
pip install .[dev]
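
Either way, a quick check that the install worked is to print the package version:

python -c 'import vbjax; print(vbjax.__version__)'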

The primary additional dependency of vbjax is [JAX](https://github.com/google/jax), which itself depends only on NumPy, SciPy & opt-einsum, so it should be safe to add to your existing projects.

gee pee you

CUDA

If you have a CUDA-enabled GPU, you can install the requisite dependencies like so:

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
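
Afterwards, JAX should list the GPU among its devices (if only CPU devices are reported, the CUDA wheels were not picked up):

python -c 'import jax; print(jax.devices())'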

M1/M2 🍎

On newer Apple machines w/ M1 or M2 GPUs, JAX has experimental GPU support, enabled by installing just two extra packages:

pip install ml-dtypes==0.2.0 jax-metal

About a third of the vbjax tests fail on this backend due to the absence of certain operations like n-dim scatter/gather & FFTs, and it may not be faster anyway, because these machines' CPUs already have excellent memory bandwidth & latency hiding.
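
If you do hit one of those unsupported operations, a given run can be forced back onto the CPU backend via JAX's standard JAX_PLATFORMS environment variable (my_script.py below stands for your own entry point):

JAX_PLATFORMS=cpu python my_script.py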

CUDA w/ Docker 🐳

BUT because matching GPU software stack versions makes aligning stars look like child's play, container images are available & auto-built w/ [GitHub Actions](.github/workflows/docker-image.yml), so you can use them w/ Docker:

docker run --rm -it ghcr.io/ins-amu/vbjax:main python3 -c 'import vbjax; print(vbjax.__version__)'

The images are built on NVIDIA runtime images, so passing --gpus all to docker run is enough for JAX to discover the GPU(s).
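
For example, to check that JAX inside the container sees the GPUs:

docker run --rm -it --gpus all ghcr.io/ins-amu/vbjax:main python3 -c 'import jax; print(jax.devices())'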

Tutorial

Contents:

Modules