neural_mass

class JRTheta(A, B, a, b, v0, nu_max, r, J, a_1, a_2, a_3, a_4, mu, I)

Bases: tuple

A

Alias for field number 0

B

Alias for field number 1

I

Alias for field number 13

J

Alias for field number 7

a

Alias for field number 2

a_1

Alias for field number 8

a_2

Alias for field number 9

a_3

Alias for field number 10

a_4

Alias for field number 11

b

Alias for field number 3

mu

Alias for field number 12

nu_max

Alias for field number 5

r

Alias for field number 6

v0

Alias for field number 4
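The "Alias for field number N" entries follow from JRTheta being a named tuple: each parameter can be read by name or by position. The sketch below uses a stand-in named tuple (JRThetaSketch, a hypothetical name) with the same field order as the signature above. The numeric values are placeholders chosen for illustration, not the library's defaults.

```python
from collections import namedtuple

# Stand-in with the same field order as the JRTheta signature.
# It is not imported from the library.
JRThetaSketch = namedtuple(
    "JRThetaSketch",
    "A B a b v0 nu_max r J a_1 a_2 a_3 a_4 mu I",
)

# Placeholder values only; these are not the library defaults.
theta = JRThetaSketch(*[float(i) for i in range(14)])

# Access by name and by position refer to the same slot.
assert theta.J == theta[7]
assert theta.I == theta[13]

# Named tuples are immutable; _replace returns a modified copy.
theta2 = theta._replace(mu=0.5)
print(theta2.mu, theta.mu)
```

The same pattern applies to the other Theta and State tuples in this module.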

class JRState(y0, y1, y2, y3, y4, y5)

Bases: tuple

y0

Alias for field number 0

y1

Alias for field number 1

y2

Alias for field number 2

y3

Alias for field number 3

y4

Alias for field number 4

y5

Alias for field number 5

jr_dfun(ys, c, p)

class BVEPTheta(tau0, I1, x0)

Bases: tuple

I1

Alias for field number 1

tau0

Alias for field number 0

x0

Alias for field number 2

bvep_dfun(ys, c, p: BVEPTheta)

class MPRTheta(tau, I, Delta, J, eta, cr, cv)

Bases: tuple

Delta

Alias for field number 2

I

Alias for field number 1

J

Alias for field number 3

cr

Alias for field number 5

cv

Alias for field number 6

eta

Alias for field number 4

tau

Alias for field number 0

class MPRState(r, V)

Bases: tuple

V

Alias for field number 1

r

Alias for field number 0

mpr_dfun(ys, c, p)

mpr_r_positive(rv, _)

class BOLDTheta(tau_s, tau_f, tau_o, alpha, te, v0, e0, epsilon, nu_0, r_0, recip_tau_s, recip_tau_f, recip_tau_o, recip_alpha, recip_e0, k1, k2, k3)

Bases: tuple

alpha

Alias for field number 3

e0

Alias for field number 6

epsilon

Alias for field number 7

k1

Alias for field number 15

k2

Alias for field number 16

k3

Alias for field number 17

nu_0

Alias for field number 8

r_0

Alias for field number 9

recip_alpha

Alias for field number 13

recip_e0

Alias for field number 14

recip_tau_f

Alias for field number 11

recip_tau_o

Alias for field number 12

recip_tau_s

Alias for field number 10

tau_f

Alias for field number 1

tau_o

Alias for field number 2

tau_s

Alias for field number 0

te

Alias for field number 4

v0

Alias for field number 5

compute_bold_theta(tau_s=0.65, tau_f=0.41, tau_o=0.98, alpha=0.32, te=0.04, v0=4.0, e0=0.4, epsilon=0.5, nu_0=40.3, r_0=25.0)

bold_dfun(sfvq, x, p: BOLDTheta)

class DCMTheta(A, B, C)

Bases: tuple

A

Alias for field number 0

B

Alias for field number 1

C

Alias for field number 2

dcm_dfun(x, u, p: DCMTheta)

Implements the classical bilinear DCM: $\dot{x} = (A + \sum_j u_j B_j) x + C u$.
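To make the bilinear form concrete, the sketch below evaluates the right-hand side with plain Python lists. The helper dcm_rhs is hypothetical and is not the library's dcm_dfun. The layout of B as a list holding one modulatory matrix per input is an assumption made for this example.

```python
def dcm_rhs(x, u, A, B, C):
    """Hypothetical helper: bilinear DCM right-hand side with plain lists."""
    n, m = len(x), len(u)
    dx = []
    for i in range(n):
        acc = 0.0
        for k in range(n):
            # Effective coupling: A[i][k] + sum_j u[j] * B[j][i][k]
            a_eff = A[i][k] + sum(u[j] * B[j][i][k] for j in range(m))
            acc += a_eff * x[k]
        # Direct input term (C u)[i]
        acc += sum(C[i][j] * u[j] for j in range(m))
        dx.append(acc)
    return dx

A = [[-1.0, 0.0], [0.5, -1.0]]
B = [[[0.0, 0.0], [1.0, 0.0]]]   # assumed layout: one matrix per input
C = [[1.0], [0.0]]
print(dcm_rhs([1.0, 2.0], [2.0], A, B, C))  # [1.0, 0.5]
```

With u = [2], the effective coupling is A + 2 B_0 = [[-1, 0], [2.5, -1]], which gives [-1, 0.5] when applied to x, and adding C u = [2, 0] yields the printed result.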

DopaTheta

alias of dopaTheta

class DopaState(r, V, u, Sa, Sg, Dp)

Bases: tuple

Dp

Alias for field number 5

Sa

Alias for field number 3

Sg

Alias for field number 4

V

Alias for field number 1

r

Alias for field number 0

u

Alias for field number 2

dopa_dfun(y, cy, p: dopaTheta)

Adaptive QIF model with dopamine modulation.

dopa_net_dfun(y, p)

Canonical form for network of dopa nodes.

dopa_r_positive(y, _)

dopa_gfun_mulr(y, p)

Provide a multiplicative r, additive V gfun.

dopa_gfun_add(y, p)

Provides an additive noise gfun.

monitor

make_offline(step_fn, sample_fn, *args)

Compute monitor samples in an offline or batch fashion.

make_timeavg(shape)

Make a time average monitor.
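A time average monitor can be pictured as an accumulator that is updated every step and read out once per sampling window. The sketch below is a hypothetical pure-Python illustration of that idea. The names and the returned (buffer, step, sample) triple are assumptions made for this example, not the confirmed return signature of make_timeavg.

```python
def make_timeavg_sketch():
    """Hypothetical closure set: accumulate states, then emit their mean."""
    def init():
        return {"total": 0.0, "count": 0}

    def step(buf, x):
        # Accumulate the running sum and the number of samples seen.
        return {"total": buf["total"] + x, "count": buf["count"] + 1}

    def sample(buf):
        # Return a fresh buffer together with the average over the window.
        return init(), buf["total"] / buf["count"]

    return init(), step, sample

buf, step, sample = make_timeavg_sketch()
for x in [1.0, 2.0, 3.0, 6.0]:
    buf = step(buf, x)
buf, avg = sample(buf)
print(avg)  # 3.0
```

Returning a new buffer instead of mutating in place keeps the functions pure, which is the style JAX transformations generally expect.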

compute_sarvas_gain(q, r, o, att, Ds=0, Dc=0) → Array

make_gain(gain, shape=None)

Make a gain-matrix monitor suitable for sEEG, EEG & MEG.

make_bold(shape, dt, p: BOLDTheta)

Make a BOLD fMRI monitor.

make_cov(shape)

make_fc(shape)

make_fft(shape, period)

loops

Functions for building time stepping loops.

heun_step(x, dfun, dt, *args, add=0, adhoc=None, return_euler=False)

Use a Heun scheme to step the state x with right-hand side dfun(.) and an additional forcing term add.
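For orientation, a Heun step is an Euler predictor followed by a trapezoidal corrector. The sketch below is a minimal scalar version, not the library's implementation. It omits the add, adhoc and return_euler options, and the name heun_step_sketch is hypothetical.

```python
def heun_step_sketch(x, dfun, dt, *args):
    """One deterministic Heun (predictor-corrector) step; forcing term omitted."""
    d1 = dfun(x, *args)              # slope at the current state
    x_euler = x + dt * d1            # Euler predictor
    d2 = dfun(x_euler, *args)        # slope at the predicted state
    return x + 0.5 * dt * (d1 + d2)  # trapezoidal corrector

x = 1.0
for _ in range(4):
    x = heun_step_sketch(x, lambda x, p: -x, 1.0, None)
    print(x)  # 0.5, 0.25, 0.125, 0.0625
```

For dx/dt = -x with dt = 1, each step halves the state, which agrees with the values in the make_ode example in this section.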

make_sde(dt, dfun, gfun, adhoc=None, return_euler=False, unroll=10)

Use a stochastic Heun scheme to integrate autonomous stochastic differential equations (SDEs).

Parameters:
dt : float

Time step.

dfun : function

Function of the form dfun(x, p) that computes drift coefficients of the stochastic differential equation.

gfun : function or float

Function of the form gfun(x, p) that computes diffusion coefficients of the stochastic differential equation. If a numerical value is provided, it is used as a constant diffusion coefficient for an additive linear SDE.

adhoc : function or None

Function of the form f(x, p) that allows making ad hoc corrections to states after a step.

return_euler : bool, default False

Return solution with local Euler estimates.

unroll : int, default 10

Forces unrolling of the time stepping loop.

Returns:
step : function

Function of the form step(x, z_t, p) that takes one step in time according to the Heun scheme.

loop : function

Function of the form loop(x0, zs, p) that iteratively calls step for all z.

Notes

In both cases, a Jax compatible parameter set p is provided, either an array or some pytree compatible structure.

Note that the integrator does not sample normally distributed noise, so this must be provided by the user.
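Because the noise is supplied by the caller, a typical pattern is to draw standard normal samples up front and pass them to the loop. The sketch below shows one common additive-noise form of the stochastic Heun scheme in plain Python. The function sde_loop_sketch, the sqrt(dt) scaling and the reuse of the same noise increment in predictor and corrector are assumptions of this example, not a statement of how the library is implemented.

```python
import math
import random

def sde_loop_sketch(x0, zs, dt, dfun, g, p=None):
    """Hypothetical additive-noise stochastic Heun loop over user-supplied zs."""
    sqrt_dt = math.sqrt(dt)
    x, out = x0, []
    for z in zs:
        noise = g * sqrt_dt * z               # constant diffusion coefficient g
        d1 = dfun(x, p)
        x_pred = x + dt * d1 + noise          # predictor includes the increment
        d2 = dfun(x_pred, p)
        x = x + 0.5 * dt * (d1 + d2) + noise  # corrector reuses the increment
        out.append(x)
    return out

rng = random.Random(42)
zs = [rng.gauss(0.0, 1.0) for _ in range(4)]  # the caller samples N(0, 1)
xs = sde_loop_sketch(1.0, zs, 1.0, lambda x, p: -x, 0.1)
print(xs)
```

Passing the noise in explicitly keeps the loop deterministic for a given zs, which makes runs reproducible.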

>>> import vbjax as vb
>>> _, sde = vb.make_sde(1.0, lambda x, p: -x, 0.1)
>>> sde(1.0, vb.randn(4), None)
Array([ 0.5093468 ,  0.30794007,  0.07600437, -0.03876263], dtype=float32)

make_ode(dt, dfun, adhoc=None)

Use a Heun scheme to integrate autonomous ordinary differential equations (ODEs).

Parameters:
dt : float

Time step.

dfun : function

Function of the form dfun(x, p) that computes derivatives of the ordinary differential equations.

adhoc : function or None

Function of the form f(x, p) that allows making ad hoc corrections to states after a step.

Returns:
step : function

Function of the form step(x, t, p) that takes one step in time according to the Heun scheme.

loop : function

Function of the form loop(x0, ts, p) that iteratively calls step for all time steps ts.

Notes

In both cases, a Jax compatible parameter set p is provided, either an array or some pytree compatible structure.

>>> import vbjax as vb, jax.numpy as np
>>> _, ode = vb.make_ode(1.0, lambda x, p: -x)
>>> ode(1.0, np.r_[:4], None)
Array([0.5   , 0.25  , 0.125 , 0.0625], dtype=float32, weak_type=True)

make_dde(dt, nh, dfun, unroll=10, adhoc=None)

Invokes make_sdde with gfun set to 0.

make_sdde(dt, nh, dfun, gfun, unroll=1, zero_delays=False, adhoc=None)

Use a stochastic Heun scheme to integrate autonomous stochastic delay differential equations (SDDEs).

Parameters:
dt : float

Time step.

nh : int

Maximum delay in time steps.

dfun : function

Function of the form dfun(xt, x, t, p) that computes drift coefficients of the stochastic differential equation.

gfun : function or float

Function of the form gfun(x, p) that computes diffusion coefficients of the stochastic differential equation. If a numerical value is provided, it is used as a constant diffusion coefficient for an additive linear SDE.

adhoc : function or None

Function of the form f(x, p) that allows making ad hoc corrections after each step.

Returns:
step : function

Function of the form step((x_t, t), z_t, p) that takes one step in time according to the Heun scheme.

loop : function

Function of the form loop((xs, t), p) that iteratively calls step for each xs[nh:] and starting time index t.

Notes

  • A Jax compatible parameter set p is provided, either an array or some pytree compatible structure.

  • The integrator does not sample normally distributed noise, so this must be provided by the user.

  • The history buffer passed to the user functions, on the corrector stage of the Heun method, does not contain the predictor stage, for performance reasons, unless zero_delays is set to True. A good compromise can be to set all zero delays to dt.
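The note on the history buffer can be illustrated with a noise-free sketch in plain Python. Here dde_loop_sketch is a hypothetical name. The assumptions are that the first nh + 1 entries of xs hold the history, and that the corrector calls dfun with the same buffer and time index as the predictor, receiving the predicted state only through the x argument.

```python
def dde_loop_sketch(xs, dfun, dt, nh, p=None):
    """Hypothetical noise-free delay loop over a history buffer xs."""
    xs = list(xs)
    for t in range(nh, len(xs) - 1):
        x = xs[t]
        d1 = dfun(xs, x, t, p)
        x_pred = x + dt * d1
        # The corrector sees the same buffer and time index; the predictor
        # is passed only as the current state, not written into the history.
        d2 = dfun(xs, x_pred, t, p)
        xs[t + 1] = x + 0.5 * dt * (d1 + d2)
    return xs

out = dde_loop_sketch([11.0] * 6, lambda xt, x, t, p: -xt[t - 2], 1.0, 2)
print(out)  # [11.0, 11.0, 11.0, 0.0, -11.0, -22.0]
```

Under these assumptions the sketch gives the same values as the make_sdde doctest in this section. Since this dfun depends only on the delayed state, both stages see the same slope and the step reduces to an Euler update.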

>>> import vbjax as vb, jax.numpy as np
>>> _, sdde = vb.make_sdde(1.0, 2, lambda xt, x, t, p: -xt[t-2], 0.0)
>>> x,t = sdde(np.ones(6)+10, None)
>>> x
Array([ 11.,  11.,  11.,   0., -11., -22.], dtype=float32)

make_continuation(run_chunk, chunk_len, max_lag, n_from, n_svar, stochastic=True)

Helper function to lower memory usage for longer simulations with time delays. Work in progress.

Takes a function

run_chunk(buf, params) -> (buf, chunk_states)

and returns another

continue_chunk(buf, params, rng_key) -> (buf, chunk_states)

The continue_chunk function wraps run_chunk and manages moving the latest states to the first part of buf and filling the rest with samples from N(0,1) if required.
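The buffer handling described above can be sketched with a one-dimensional list. This is a hypothetical illustration only: the function name continue_buffer_sketch, the history length of max_lag + 1 and the zero fill in the non-stochastic case are assumptions, not details confirmed by the library.

```python
import random

def continue_buffer_sketch(buf, max_lag, rng, stochastic=True):
    """Hypothetical buffer rollover between chunks of a delayed simulation."""
    n_hist = max_lag + 1
    # The latest states become the history at the front of the new buffer.
    head = buf[-n_hist:]
    n_rest = len(buf) - n_hist
    # Remaining slots hold fresh N(0, 1) samples, or zeros without noise.
    tail = [rng.gauss(0.0, 1.0) if stochastic else 0.0 for _ in range(n_rest)]
    return head + tail

rng = random.Random(0)
buf = [float(i) for i in range(8)]  # pretend output of one chunk, max_lag = 2
new_buf = continue_buffer_sketch(buf, 2, rng)
print(new_buf[:3])  # [5.0, 6.0, 7.0]
```

Only one chunk of states plus the delay history needs to be held at a time, which is where the memory saving would come from.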

connectome

Utilities for connectomes.

make_conn_latent_mvnorm(SCs, nc=10, return_full=False)

Make a latent multivariate normal distribution over connectomes.

Parameters:
SCs : (nconn, n, n) array

Array of connectomes in a given parcellation.

nc : int, optional

Number of components to use for the latent space.

return_full : bool, optional

Whether or not to return extra information on the SVD.

Returns:
u_mean : (nc, ), array_like

Mean of distribution in latent space.

u_cov : (nc, nc), array_like

Covariance of distribution in latent space.

xfm : function

Maps latent vector to full connectome.

u_cov : (nc, nc) array

Covariance of the multivariate normal. Returned if return_full=True.

u : (nconn, nconn) array

Left singular vectors corresponding to connectomes embedded. Returned if return_full=True.

s : (nconn) array

Singular values. Returned if return_full=True.

vt : (nconn, n*n) array

Right singular vectors. Returned if return_full=True.

nconf : int

Number of confusions induced by dimensionality reduction. Returned if return_full=True.

layers

make_dense_layers(in_dim, latent_dims=[10], out_dim=None, init_scl=0.1, extra_in=0, act_fn=jax.nn.leaky_relu, key=jax.random.PRNGKey(42))

Make a dense neural network with the given latent layer sizes.

sparse

make_spmv(A, is_symmetric=False, use_scipy=False)

Make a closure for a general sparse matrix-vector multiplication.

Parameters:
A : scipy.sparse.csr_matrix

Constant sparse matrix.

is_symmetric : bool, optional, default False

Whether matrix is symmetric.

use_scipy : bool, optional, default False

Use scipy.

Returns:
spmv : function

Function implementing sparse matrix-vector multiply with support for gradients in Jax.
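As a reminder of what a CSR matrix-vector product computes, the sketch below builds a closure over the three CSR arrays (data, indices, indptr) in plain Python. The name make_spmv_sketch is hypothetical. The sketch has no gradient support and is not the library's kernel.

```python
def make_spmv_sketch(data, indices, indptr):
    """Hypothetical closure over CSR arrays returning a matvec function."""
    def spmv(x):
        y = []
        for row in range(len(indptr) - 1):
            start, end = indptr[row], indptr[row + 1]
            # Sum the stored entries of this row against the matching x entries.
            y.append(sum(data[k] * x[indices[k]] for k in range(start, end)))
        return y
    return spmv

# CSR form of [[1, 0, 2], [0, 0, 3], [4, 5, 0]]
spmv = make_spmv_sketch([1.0, 2.0, 3.0, 4.0, 5.0], [0, 2, 2, 0, 1], [0, 2, 3, 5])
print(spmv([1.0, 1.0, 1.0]))  # [3.0, 3.0, 9.0]
```

Capturing the constant matrix in a closure means the returned function takes only the vector, which matches the closure description above.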

csr_to_jax_bcoo(A: csr_matrix)

Convert CSR format to batched COO format.

make_sg_spmv(A: csr_matrix, use_pmap=False, sharding: PositionalSharding = None)

Make a SpMV kernel with generic scatter-gather operations.

util

to_np(x: Array) → ndarray

to_jax(x: ndarray)

Move NumPy array to JAX via DLPack.

tuple_meshgrid(tup)

Applies meshgrid to arrays in a named tuple.

tuple_ravel(tup)

Flatten arrays in fields of tuple.

tuple_shard(tup, n)

Shard arrays in fields of tuple.

shtlc

SHT based local coupling functions.

make_grid_shtns(lmax, nlat, nlon, D)

Create shtns object and grid as in make_grid.

make_lm(lmax: int)

make_grid(nlat, nlon)

Create a grid for the SHT, with phi as latitude and theta as longitude.

grid_pairwise_distance(theta, phi)

Compute pairwise distances on grid. Memory intensive for large grids.

randn_relaxed(sht)

For the shtns object sht, return a random spatial array captured by sht.

kernel_diff(D, l)

Compute diffusion kernel, l in shtns order.

kernel_dist_origin(theta, phi)

Compute distance to origin on grid.

kernel_sh_normalized(sht, k)

kernel_laplace(sht, theta, phi, size)

Compute spatial & spectral coefficients for Laplacian kernel.

kernel_gaussian(sht, theta, phi, size)

Compute spatial & spectral coefficients for Gaussian kernel.

kernel_mexican_hat(sht, theta, phi, size)

Compute spatial & spectral coefficients for Mexican hat kernel.

kernel_conv_prep(sht, k)

Prepares evaluated kernel for SHT convolution.

kernel_estimate_shtns(sht, k, x0)

Estimate effective kernel for kernel k with state x0 for shtns object sht.

make_shtdiff_np(lmax, nlat, nlon, D, return_L=False, np=numpy)

Construct SHT diff implementation in plain NumPy.

make_shtdiff(nlat, lmax=None, nlon=None, D=0.0004, return_L=False)

Construct SHT diff implementation with Jax.