neural_mass
- class JRTheta(A, B, a, b, v0, nu_max, r, J, a_1, a_2, a_3, a_4, mu, I)
Bases: tuple
- A
Alias for field number 0
- B
Alias for field number 1
- I
Alias for field number 13
- J
Alias for field number 7
- a
Alias for field number 2
- a_1
Alias for field number 8
- a_2
Alias for field number 9
- a_3
Alias for field number 10
- a_4
Alias for field number 11
- b
Alias for field number 3
- mu
Alias for field number 12
- nu_max
Alias for field number 5
- r
Alias for field number 6
- v0
Alias for field number 4
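Each "Alias for field number N" entry means the class is a named tuple (its base is tuple), so every field can be read by name or by position. The sketch below illustrates this with a stand-in class, JRThetaSketch, which is hypothetical (not imported from vbjax) and only copies the documented field order. The numeric values are placeholders, not model defaults.

```python
from typing import NamedTuple

# Hypothetical stand-in that mirrors the documented JRTheta field order.
class JRThetaSketch(NamedTuple):
    A: float
    B: float
    a: float
    b: float
    v0: float
    nu_max: float
    r: float
    J: float
    a_1: float
    a_2: float
    a_3: float
    a_4: float
    mu: float
    I: float

# Placeholder values 0.0 .. 13.0, chosen only so positions are easy to check.
theta = JRThetaSketch(*[float(i) for i in range(14)])

# "Alias for field number N": name-based and positional access agree.
assert theta.v0 == theta[4]
assert theta.I == theta[13]

# Named tuples are immutable; _replace returns a modified copy,
# which is convenient when varying a single parameter.
theta2 = theta._replace(mu=0.5)
```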
- class JRState(y0, y1, y2, y3, y4, y5)
Bases: tuple
- y0
Alias for field number 0
- y1
Alias for field number 1
- y2
Alias for field number 2
- y3
Alias for field number 3
- y4
Alias for field number 4
- y5
Alias for field number 5
- jr_dfun(ys, c, p)
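jr_dfun is listed without a description. In the standard Jansen-Rit model, the potential-to-rate nonlinearity is the sigmoid S(v) = 2e0 / (1 + exp(r (v0 - v))), where 2e0 is the maximal firing rate. The sketch below assumes that the JRTheta field nu_max plays the role of 2e0 and that v0 and r keep their textbook meaning; this is an assumption about the model, not a reading of the vbjax source, and jr_sigmoid_sketch is a hypothetical name.

```python
import numpy as np

def jr_sigmoid_sketch(v, nu_max, r, v0):
    """Jansen-Rit style sigmoid (hypothetical helper; assumes nu_max = 2*e0).

    Maps a mean membrane potential v to a mean firing rate in (0, nu_max).
    """
    return nu_max / (1.0 + np.exp(r * (v0 - v)))

# Placeholder parameter values for illustration only.
nu_max, r, v0 = 5.0, 0.56, 6.0
v = np.array([0.0, v0, 20.0])
rates = jr_sigmoid_sketch(v, nu_max, r, v0)
# The rate is half-maximal at v = v0 and saturates toward nu_max.
```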
- class BVEPTheta(tau0, I1, x0)
Bases: tuple
- I1
Alias for field number 1
- tau0
Alias for field number 0
- x0
Alias for field number 2
- class MPRTheta(tau, I, Delta, J, eta, cr, cv)
Bases: tuple
- Delta
Alias for field number 2
- I
Alias for field number 1
- J
Alias for field number 3
- cr
Alias for field number 5
- cv
Alias for field number 6
- eta
Alias for field number 4
- tau
Alias for field number 0
- mpr_dfun(ys, c, p)
- mpr_r_positive(rv, _)
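mpr_r_positive has the two-argument (x, p) shape of the adhoc correction accepted by make_ode and make_sde, and its name suggests that it keeps the firing rate r non-negative after each step. A minimal sketch of such a correction follows. It assumes the state stacks r and V along the first axis; that layout, and the name r_positive_sketch, are assumptions rather than documented behavior.

```python
import numpy as np

def r_positive_sketch(rv, _):
    """Hypothetical adhoc correction: clamp the rate r at zero, leave V untouched.

    Assumes rv has shape (2, ...) with rv[0] = r and rv[1] = V.
    """
    r, v = rv
    return np.stack([np.where(r > 0, r, 0.0), v])

state = np.array([[-0.1, 0.3], [-2.0, -1.5]])   # two nodes: rates, then potentials
fixed = r_positive_sketch(state, None)
```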
- class BOLDTheta(tau_s, tau_f, tau_o, alpha, te, v0, e0, epsilon, nu_0, r_0, recip_tau_s, recip_tau_f, recip_tau_o, recip_alpha, recip_e0, k1, k2, k3)
Bases: tuple
- alpha
Alias for field number 3
- e0
Alias for field number 6
- epsilon
Alias for field number 7
- k1
Alias for field number 15
- k2
Alias for field number 16
- k3
Alias for field number 17
- nu_0
Alias for field number 8
- r_0
Alias for field number 9
- recip_alpha
Alias for field number 13
- recip_e0
Alias for field number 14
- recip_tau_f
Alias for field number 11
- recip_tau_o
Alias for field number 12
- recip_tau_s
Alias for field number 10
- tau_f
Alias for field number 1
- tau_o
Alias for field number 2
- tau_s
Alias for field number 0
- te
Alias for field number 4
- v0
Alias for field number 5
- compute_bold_theta(tau_s=0.65, tau_f=0.41, tau_o=0.98, alpha=0.32, te=0.04, v0=4.0, e0=0.4, epsilon=0.5, nu_0=40.3, r_0=25.0)
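compute_bold_theta takes the ten primary Balloon-Windkessel parameters and returns a BOLDTheta that also carries five precomputed reciprocals and the coefficients k1, k2, k3. The sketch below assumes the commonly used forms k1 = 4.3 nu_0 e0 te, k2 = epsilon r_0 e0 te and k3 = 1 - epsilon (Stephan et al., 2007, whose 1.5 T values nu_0 = 40.3 and r_0 = 25 match the defaults above). It returns a plain dict under the hypothetical name bold_theta_sketch; check the vbjax source before relying on these exact expressions.

```python
def bold_theta_sketch(tau_s=0.65, tau_f=0.41, tau_o=0.98, alpha=0.32, te=0.04,
                      v0=4.0, e0=0.4, epsilon=0.5, nu_0=40.3, r_0=25.0):
    """Hypothetical re-derivation of the 18 BOLDTheta fields as a plain dict."""
    theta = dict(tau_s=tau_s, tau_f=tau_f, tau_o=tau_o, alpha=alpha, te=te,
                 v0=v0, e0=e0, epsilon=epsilon, nu_0=nu_0, r_0=r_0)
    # Reciprocals, presumably precomputed so the model multiplies instead of dividing.
    theta.update(recip_tau_s=1.0 / tau_s, recip_tau_f=1.0 / tau_f,
                 recip_tau_o=1.0 / tau_o, recip_alpha=1.0 / alpha,
                 recip_e0=1.0 / e0)
    # Assumed BOLD signal coefficients (standard Balloon-Windkessel forms).
    theta.update(k1=4.3 * nu_0 * e0 * te,
                 k2=epsilon * r_0 * e0 * te,
                 k3=1.0 - epsilon)
    return theta

theta = bold_theta_sketch()
```

With the defaults this gives k1 of about 2.773, k2 = 0.2 and k3 = 0.5.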
- class DCMTheta(A, B, C)
Bases: tuple
- A
Alias for field number 0
- B
Alias for field number 1
- C
Alias for field number 2
- dcm_dfun(x, u, p: DCMTheta)
Implements the classical bilinear DCM, $\dot{x} = \left(A + \sum_j u_j B_j\right) x + C u$.
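To make the bilinear equation concrete, here is a NumPy sketch of its right-hand side. The array shapes (B stacked as one n-by-n matrix per input) and the name dcm_dfun_sketch are assumptions; the documented dcm_dfun takes a DCMTheta p instead of separate arrays.

```python
import numpy as np

def dcm_dfun_sketch(x, u, A, B, C):
    """Bilinear DCM right-hand side: (A + sum_j u_j B_j) x + C u.

    Assumed shapes: x (n,), u (m,), A (n, n), B (m, n, n), C (n, m).
    """
    # Effective connectivity: contract the inputs against the stack of B_j.
    J = A + np.einsum('j,jkl->kl', u, B)
    return J @ x + C @ u

n, m = 2, 1
A = np.array([[-1.0, 0.0], [0.5, -1.0]])
B = np.zeros((m, n, n))
B[0, 1, 0] = 0.3                      # input 0 strengthens the 0 -> 1 connection
C = np.array([[1.0], [0.0]])          # input 0 drives region 0 only
dx = dcm_dfun_sketch(np.array([1.0, 0.0]), np.array([1.0]), A, B, C)
```

Here the input raises the 0 -> 1 coupling from 0.5 to 0.8, so dx equals [0.0, 0.8].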
- DopaTheta
alias of dopaTheta
- class DopaState(r, V, u, Sa, Sg, Dp)
Bases: tuple
- Dp
Alias for field number 5
- Sa
Alias for field number 3
- Sg
Alias for field number 4
- V
Alias for field number 1
- r
Alias for field number 0
- u
Alias for field number 2
- dopa_dfun(y, cy, p: dopaTheta)
Adaptive QIF model with dopamine modulation.
- dopa_net_dfun(y, p)
Canonical form for a network of dopa nodes.
- dopa_r_positive(y, _)
- dopa_gfun_mulr(y, p)
Provides a gfun with multiplicative noise on r and additive noise on V.
- dopa_gfun_add(y, p)
Provides an additive noise gfun.
monitor
- make_offline(step_fn, sample_fn, *args)
Compute monitor samples in an offline or batch fashion.
- make_timeavg(shape)
Make a time average monitor.
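A time-average monitor reduces a high-rate simulated trajectory to a lower sampling rate by averaging consecutive windows. The offline NumPy sketch below shows only that reduction. The window length and the name timeavg_offline_sketch are hypothetical, and the sketch does not reproduce the interface returned by make_timeavg.

```python
import numpy as np

def timeavg_offline_sketch(xs, window):
    """Average consecutive, non-overlapping windows of `window` samples.

    xs has shape (n_time, ...); trailing samples that do not fill a full
    window are dropped.
    """
    n_win = xs.shape[0] // window
    trimmed = xs[:n_win * window]
    return trimmed.reshape((n_win, window) + xs.shape[1:]).mean(axis=1)

xs = np.arange(10.0).reshape(10, 1)        # 10 time points, 1 node
avg = timeavg_offline_sketch(xs, window=4)  # windows [0..3] and [4..7]
```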
- compute_sarvas_gain(q, r, o, att, Ds=0, Dc=0) -> Array
- make_gain(gain, shape=None)
Make a gain-matrix monitor suitable for sEEG, EEG & MEG.
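In a gain-matrix monitor, each sensor signal is a linear combination of source activities, y(t) = G x(t), where G is the gain (lead-field) matrix. For MEG, the Sarvas formula is one way to obtain G, which is presumably what compute_sarvas_gain evaluates. The sketch applies a random placeholder G offline with NumPy. apply_gain_sketch is a hypothetical name, and the code is not the monitor interface returned by make_gain.

```python
import numpy as np

def apply_gain_sketch(gain, xs):
    """Project source activity onto sensors, one matrix product per time point.

    gain has shape (n_sensors, n_sources); xs has shape (n_time, n_sources).
    """
    return xs @ gain.T

rng = np.random.default_rng(0)
gain = rng.standard_normal((3, 5))   # placeholder, not a real lead field
xs = rng.standard_normal((100, 5))   # 100 time points, 5 sources
ys = apply_gain_sketch(gain, xs)     # 100 time points, 3 sensors
```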
- make_cov(shape)
- make_fc(shape)
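make_cov and make_fc are undocumented. Judging by their names, they accumulate a covariance matrix and a functional connectivity (FC) matrix. Offline, FC is commonly computed as the Pearson correlation between node time series, which can be obtained by normalizing the covariance. Treating FC as Pearson correlation is an assumption about vbjax, and cov_and_fc_sketch is a hypothetical name.

```python
import numpy as np

def cov_and_fc_sketch(xs):
    """Covariance and Pearson-correlation FC of xs with shape (n_time, n_nodes)."""
    cov = np.cov(xs, rowvar=False)
    sd = np.sqrt(np.diag(cov))
    fc = cov / np.outer(sd, sd)   # normalize covariance to correlation
    return cov, fc

rng = np.random.default_rng(1)
xs = rng.standard_normal((500, 4))
xs[:, 1] = xs[:, 0] + 0.1 * xs[:, 1]   # make nodes 0 and 1 strongly correlated
cov, fc = cov_and_fc_sketch(xs)
```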
- make_fft(shape, period)
loops
Functions for building time-stepping loops.
- heun_step(x, dfun, dt, *args, add=0, adhoc=None, return_euler=False)
Use a Heun scheme to step the state x with the right-hand side dfun(.) and an additional forcing term add.
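Heun's method is a predictor-corrector scheme. An Euler step gives a predicted state, and the final update averages the slopes at the current and the predicted states. A minimal NumPy sketch follows. It assumes that add enters both the predictor and the corrector and that adhoc is applied after each stage. These choices mirror the parameter list above but are not taken from the vbjax source, and heun_step_sketch is a hypothetical name.

```python
import numpy as np

def heun_step_sketch(x, dfun, dt, *args, add=0.0, adhoc=None, return_euler=False):
    """One Heun (predictor-corrector) step with an additive forcing term."""
    adhoc = adhoc or (lambda y, *_: y)
    d1 = dfun(x, *args)                                  # slope at the current state
    xi = adhoc(x + dt * d1 + add, *args)                 # Euler predictor
    d2 = dfun(xi, *args)                                 # slope at the predicted state
    nx = adhoc(x + dt * 0.5 * (d1 + d2) + add, *args)    # corrector
    return (nx, xi) if return_euler else nx

x1 = heun_step_sketch(1.0, lambda x: -x, 1.0)
x1b, euler = heun_step_sketch(1.0, lambda x: -x, 1.0, return_euler=True)
```

For dx/dt = -x with dt = 1, the Euler predictor gives 0.0 and the Heun step maps 1.0 to 0.5, the first value in the make_ode example.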
- make_sde(dt, dfun, gfun, adhoc=None, return_euler=False, unroll=10)
Use a stochastic Heun scheme to integrate autonomous stochastic differential equations (SDEs).
- Parameters:
- dt : float
Time step.
- dfun : function
Function of the form dfun(x, p) that computes the drift coefficients of the stochastic differential equation.
- gfun : function or float
Function of the form gfun(x, p) that computes the diffusion coefficients of the stochastic differential equation. If a numerical value is provided, it is used as a constant diffusion coefficient for an additive linear SDE.
- adhoc : function or None
Function of the form f(x, p) that allows making ad hoc corrections to states after a step.
- return_euler : bool, default False
Also return the local Euler estimates alongside the solution.
- unroll : int, default 10
Unroll factor for the time-stepping loop.
- Returns:
- step : function
Function of the form step(x, z_t, p) that takes one step in time according to the Heun scheme.
- loop : function
Function of the form loop(x0, zs, p) that iteratively calls step for all z in zs.
Notes
In both cases, a JAX-compatible parameter set p is provided, either an array or some pytree-compatible structure.
Note that the integrator does not sample normally distributed noise, so this must be provided by the user.
>>> import vbjax as vb
>>> _, sde = vb.make_sde(1.0, lambda x, p: -x, 0.1)
>>> sde(1.0, vb.randn(4), None)
Array([ 0.5093468 , 0.30794007, 0.07600437, -0.03876263], dtype=float32)
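Because the caller supplies the noise samples zs, the loop is a deterministic function of x0, zs and p. The sketch below shows one plausible form of the stochastic Heun loop in plain NumPy. Scaling the noise by sqrt(dt) and adding the same noise increment in the predictor and the corrector are assumptions, and make_sde_sketch is a hypothetical name.

```python
import numpy as np

def make_sde_sketch(dt, dfun, gfun):
    """Return (step, loop) for a stochastic Heun scheme with caller-supplied noise."""
    if not callable(gfun):
        sig = gfun
        gfun = lambda x, p: sig                 # constant diffusion -> additive noise
    sqrt_dt = np.sqrt(dt)

    def step(x, z_t, p):
        noise = gfun(x, p) * sqrt_dt * z_t      # assumed Wiener-increment scaling
        d1 = dfun(x, p)
        xi = x + dt * d1 + noise                # predictor
        d2 = dfun(xi, p)
        return x + dt * 0.5 * (d1 + d2) + noise  # corrector

    def loop(x0, zs, p):
        xs, x = [], x0
        for z_t in zs:                          # one step per noise sample
            x = step(x, z_t, p)
            xs.append(x)
        return np.array(xs)

    return step, loop

_, sde = make_sde_sketch(1.0, lambda x, p: -x, 0.1)
zs = np.random.default_rng(0).standard_normal(4)   # noise is sampled by the caller
xs = sde(1.0, zs, None)
```

With a diffusion coefficient of 0.0, the loop reduces to the deterministic Heun scheme and yields 0.5, 0.25, 0.125, 0.0625 for dx/dt = -x.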
- make_ode(dt, dfun, adhoc=None)
Use a Heun scheme to integrate autonomous ordinary differential equations (ODEs).
- Parameters:
- dt : float
Time step.
- dfun : function
Function of the form dfun(x, p) that computes the derivatives of the ordinary differential equations.
- adhoc : function or None
Function of the form f(x, p) that allows making ad hoc corrections to states after a step.
- Returns:
- step : function
Function of the form step(x, t, p) that takes one step in time according to the Heun scheme.
- loop : function
Function of the form loop(x0, ts, p) that iteratively calls step for all time steps in ts.
Notes
In both cases, a JAX-compatible parameter set p is provided, either an array or some pytree-compatible structure.
>>> import vbjax as vb, jax.numpy as np
>>> _, ode = vb.make_ode(1.0, lambda x, p: -x)
>>> ode(1.0, np.r_[:4], None)
Array([0.5 , 0.25 , 0.125 , 0.0625], dtype=float32, weak_type=True)
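The (step, loop) pair returned by make_ode can be pictured as two closures over dt, dfun and adhoc. In the NumPy sketch below, the time value t is only passed through because the system is autonomous, and adhoc is applied after each stage. This structure and the name make_ode_sketch are assumptions for illustration, not the vbjax implementation.

```python
import numpy as np

def make_ode_sketch(dt, dfun, adhoc=None):
    """Return (step, loop) closures for Heun integration of x' = dfun(x, p)."""
    adhoc = adhoc or (lambda x, p: x)

    def step(x, t, p):
        # t is unused: the system is autonomous, so only the step count matters.
        d1 = dfun(x, p)
        xi = adhoc(x + dt * d1, p)
        d2 = dfun(xi, p)
        return adhoc(x + dt * 0.5 * (d1 + d2), p)

    def loop(x0, ts, p):
        xs, x = [], x0
        for t in ts:
            x = step(x, t, p)
            xs.append(x)
        return np.array(xs)

    return step, loop

_, ode = make_ode_sketch(1.0, lambda x, p: -x)
xs = ode(1.0, np.arange(4), None)

# An adhoc correction that keeps the state non-negative.
_, ode_clip = make_ode_sketch(1.0, lambda x, p: -2.0 * x, adhoc=lambda x, p: max(x, 0.0))
xs_clip = ode_clip(1.0, range(2), None)
```

The first loop returns 0.5, 0.25, 0.125, 0.0625, matching the make_ode example. In the second, clamping the predictor at zero drives the state to 0.0 instead of letting the negative predictor push it back to 1.0.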
- make_dde(dt, nh, dfun, unroll=10, adhoc=None)
Invokes make_sdde with gfun set to 0.
- make_sdde(dt, nh, dfun, gfun, unroll=1, zero_delays=False, adhoc=None)
Use a stochastic Heun scheme to integrate autonomous stochastic delay differential equations (SDDEs).
- Parameters:
- dt : float
Time step.
- nh : int
Maximum delay in time steps.
- dfun : function
Function of the form dfun(xt, x, t, p) that computes the drift coefficients of the stochastic delay differential equation.
- gfun : function or float
Function of the form gfun(x, p) that computes the diffusion coefficients of the stochastic differential equation. If a numerical value is provided, it is used as a constant diffusion coefficient for an additive linear SDE.
- adhoc : function or None
Function of the form f(x, p) that allows making ad hoc corrections after each step.
- Returns:
- step : function
Function of the form step((x_t, t), z_t, p) that takes one step in time according to the Heun scheme.
- loop : function
Function of the form loop((xs, t), p) that iteratively calls step for each of xs[nh:], starting at time index t.
Notes
A JAX-compatible parameter set p is provided, either an array or some pytree-compatible structure.
The integrator does not sample normally distributed noise, so this must be provided by the user.
For performance reasons, the history buffer passed to the user functions during the corrector stage of the Heun method does not contain the predictor-stage state, unless zero_delays is set to True. A good compromise can be to set all zero delays to dt.
>>> import vbjax as vb, jax.numpy as np
>>> _, sdde = vb.make_sdde(1.0, 2, lambda xt, x, t, p: -xt[t-2], 0.0)
>>> x,t = sdde(np.ones(6)+10, None)
>>> x
Array([ 11., 11., 11., 0., -11., -22.], dtype=float32)
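Delay handling can be pictured as a single buffer. Its first nh + 1 entries hold the initial history, and the remaining entries are filled one step at a time, so dfun can read delayed states by index. The NumPy sketch below illustrates this idea, using the hypothetical name sdde_loop_sketch and its own indexing convention. The exact buffer handling in make_sdde may differ, for example in which entries the corrector stage sees, so this sketch is not expected to reproduce the doctest values above.

```python
import numpy as np

def sdde_loop_sketch(buf, nh, dt, dfun, g=0.0, zs=None):
    """Heun integration of a delay equation stored in a single history buffer.

    buf[:nh + 1] is the initial history; buf[nh + 1:] is overwritten in order.
    dfun(buf, x, t) may read delayed states such as buf[t - nh].
    """
    buf = np.array(buf, dtype=float)
    n_steps = buf.shape[0] - (nh + 1)
    zs = np.zeros(n_steps) if zs is None else zs
    for i in range(n_steps):
        t = nh + i                        # index of the current state
        x = buf[t]
        noise = g * np.sqrt(dt) * zs[i]   # caller-supplied standard normal samples
        d1 = dfun(buf, x, t)
        xi = x + dt * d1 + noise          # predictor, not written to the buffer
        d2 = dfun(buf, xi, t + 1)
        buf[t + 1] = x + dt * 0.5 * (d1 + d2) + noise
    return buf

xs = sdde_loop_sketch(np.ones(5), nh=2, dt=1.0,
                      dfun=lambda buf, x, t: -0.5 * buf[t - 2])
```

Here both steps only read the constant history, so each lowers the state by 0.5, giving 1, 1, 1, 0.5, 0.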
- make_continuation(run_chunk, chunk_len, max_lag, n_from, n_svar, stochastic=True)
Helper function to lower memory usage for longer simulations with time delays (work in progress).
Takes a function run_chunk(buf, params) -> (buf, chunk_states) and returns another function, continue_chunk(buf, params, rng_key) -> (buf, chunk_states).
The continue_chunk function wraps run_chunk and manages moving the latest states to the first part of buf and filling the rest with samples from N(0,1) if required.
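The buffer bookkeeping described above can be sketched in NumPy. The last max_lag states become the history for the next chunk, and the rest of the buffer is refilled with N(0, 1) samples when the simulation is stochastic. The name shift_buffer_sketch is hypothetical, and a NumPy Generator is swapped in for the JAX rng_key that continue_chunk receives.

```python
import numpy as np

def shift_buffer_sketch(buf, max_lag, rng=None):
    """Prepare a history buffer for the next chunk of a delayed simulation.

    The last max_lag states are moved to the front; the remaining entries are
    filled with N(0, 1) samples when rng is given, otherwise with zeros.
    """
    new = np.zeros_like(buf)
    new[:max_lag] = buf[-max_lag:]
    if rng is not None:
        new[max_lag:] = rng.standard_normal(new[max_lag:].shape)
    return new

buf = np.arange(10.0)
det = shift_buffer_sketch(buf, max_lag=3)
sto = shift_buffer_sketch(buf, max_lag=3, rng=np.random.default_rng(0))
```

Only one chunk of states has to be held in memory at a time, which is where the memory saving comes from.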
connectome
Utilities for connectomes.
- make_conn_latent_mvnorm(SCs, nc=10, return_full=False)
Make a latent multivariate normal distribution over connectomes.
- Parameters:
- SCs : (nconn, n, n) array
Array of connectomes in a given parcellation.
- nc : int, optional
Number of components to use for the latent space.
- return_full : bool, optional
Whether or not to return extra information on the SVD.
- Returns:
- u_mean : (nc,) array_like
Mean of the distribution in latent space.
- u_cov : (nc, nc) array_like
Covariance of the distribution in latent space.
- xfm : function
Maps a latent vector to a full connectome.
- u_cov : (nc, nc) array
Covariance of the multivariate normal. Returned if return_full=True.
- u : (nconn, nconn) array
Left singular vectors, one row per embedded connectome. Returned if return_full=True.
- s : (nconn,) array
Singular values. Returned if return_full=True.
- vt : (nconn, n*n) array
Right singular vectors. Returned if return_full=True.
- nconf : int
Number of confusions induced by dimensionality reduction. Returned if return_full=True.
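The shapes listed above suggest a truncated singular value decomposition of the flattened connectomes. Each connectome becomes a row of an (nconn, n*n) matrix, and its first nc left-singular coordinates serve as a latent vector. A Gaussian is then fitted to those vectors. The NumPy sketch below follows that reading. The scaling of the latent coordinates and the name latent_mvnorm_sketch are assumptions, not the vbjax implementation.

```python
import numpy as np

def latent_mvnorm_sketch(SCs, nc=10):
    """Fit a Gaussian to truncated-SVD coordinates of flattened connectomes."""
    nconn, n, _ = SCs.shape
    X = SCs.reshape(nconn, n * n)                     # one connectome per row
    u, s, vt = np.linalg.svd(X, full_matrices=False)  # u: (nconn, k), vt: (k, n*n)
    U = u[:, :nc]                                     # latent coordinates
    u_mean = U.mean(axis=0)
    u_cov = np.atleast_2d(np.cov(U, rowvar=False))

    def xfm(z):
        # Map a latent vector back to an (n, n) connectome.
        return ((z * s[:nc]) @ vt[:nc]).reshape(n, n)

    return u_mean, u_cov, xfm, U

rng = np.random.default_rng(0)
SCs = rng.random((5, 4, 4))                # 5 placeholder 4-by-4 connectomes
u_mean, u_cov, xfm, U = latent_mvnorm_sketch(SCs, nc=5)
```

A new connectome can then be drawn by sampling z from N(u_mean, u_cov) and calling xfm(z). With nc equal to nconn, xfm maps each embedded row of U back to its original connectome exactly.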
layers
- make_dense_layers(in_dim, latent_dims=[10], out_dim=None, init_scl=0.1, extra_in=0, act_fn=jax.nn.leaky_relu, key=jax.random.PRNGKey(42))
Make a dense neural network with the given latent layer sizes.
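A dense network of this kind is a stack of affine maps with a nonlinearity between them. The NumPy sketch below is a hypothetical stand-in (make_dense_layers_sketch). It assumes that extra_in widens the input layer, that out_dim defaults to in_dim, that weights are drawn from a normal distribution scaled by init_scl, and that no activation follows the last layer. The entry above confirms none of these details.

```python
import numpy as np

def leaky_relu(x, negative_slope=0.01):
    return np.where(x >= 0, x, negative_slope * x)

def make_dense_layers_sketch(in_dim, latent_dims=(10,), out_dim=None,
                             init_scl=0.1, extra_in=0, act_fn=leaky_relu, seed=42):
    """Return (params, forward) for a small fully connected network."""
    out_dim = in_dim if out_dim is None else out_dim     # assumed default
    sizes = [in_dim + extra_in, *latent_dims, out_dim]
    rng = np.random.default_rng(seed)
    params = [(init_scl * rng.standard_normal((m, n)), np.zeros(n))
              for m, n in zip(sizes[:-1], sizes[1:])]

    def forward(params, x):
        for i, (w, b) in enumerate(params):
            x = x @ w + b
            if i < len(params) - 1:    # activation on hidden layers only
                x = act_fn(x)
        return x

    return params, forward

params, fwd = make_dense_layers_sketch(3, latent_dims=[8, 8], extra_in=2)
y = fwd(params, np.ones((4, 5)))       # batch of 4; in_dim + extra_in = 5
```

The sketch uses a tuple default for latent_dims to avoid Python's shared mutable default pitfall.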
sparse
- make_spmv(A, is_symmetric=False, use_scipy=False)
Make a closure for a general sparse matrix-vector multiplication.
- Parameters:
- A : scipy.sparse.csr_matrix
Constant sparse matrix.
- is_symmetric : bool, optional, default False
Whether the matrix is symmetric.
- use_scipy : bool, optional, default False
Whether to use SciPy.
- Returns:
- spmv : function
Function implementing sparse matrix-vector multiplication with support for gradients in JAX.
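Gradient support needs a second product. For y = A x, the vector-Jacobian product with respect to x is A^T g, so a differentiable SpMV must also multiply by the transpose. When the matrix is symmetric, that transpose is A itself, which is presumably what is_symmetric exploits. The NumPy sketch below computes both products directly on CSR arrays. The helper names are hypothetical, and no JAX custom-derivative machinery is shown.

```python
import numpy as np

def csr_matvec(indptr, indices, data, x):
    """y = A @ x for A stored in CSR form."""
    y = np.zeros(len(indptr) - 1)
    for i in range(len(y)):
        lo, hi = indptr[i], indptr[i + 1]
        y[i] = data[lo:hi] @ x[indices[lo:hi]]   # dot product of row i with x
    return y

def csr_rmatvec(indptr, indices, data, g, n_cols):
    """A.T @ g for A in CSR form: the backward pass of csr_matvec w.r.t. x."""
    out = np.zeros(n_cols)
    for i in range(len(indptr) - 1):
        lo, hi = indptr[i], indptr[i + 1]
        np.add.at(out, indices[lo:hi], data[lo:hi] * g[i])  # scatter row i
    return out

# CSR arrays for A = [[1, 0, 2], [0, 3, 0]]
indptr, indices, data = np.array([0, 2, 3]), np.array([0, 2, 1]), np.array([1.0, 2.0, 3.0])
A = np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]])
x, g = np.array([1.0, 2.0, 3.0]), np.array([1.0, -1.0])
```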
- csr_to_jax_bcoo(A: csr_matrix)
Convert a CSR matrix to JAX's batched COO (BCOO) format.
- make_sg_spmv(A: csr_matrix, use_pmap=False, sharding: PositionalSharding = None)
Make an SpMV kernel with generic scatter-gather operations.
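Both entries rely on the coordinate (COO) view of a sparse matrix, in which every stored entry carries explicit (row, column) indices. Expanding CSR row pointers into row indices gives that view. An SpMV then reduces to a gather (reading x at the column indices) followed by a scatter-add into the rows. The NumPy sketch below uses hypothetical helper names; JAX's BCOO stores its indices as an (nnz, 2) array like the one built here.

```python
import numpy as np

def csr_to_coo_sketch(indptr, indices, data):
    """Expand CSR row pointers into explicit (row, col) pairs of shape (nnz, 2)."""
    n_rows = len(indptr) - 1
    rows = np.repeat(np.arange(n_rows), np.diff(indptr))
    return data, np.stack([rows, indices], axis=1)

def sg_spmv_sketch(data, coo_idx, x, n_rows):
    """SpMV as a gather (x at columns) plus a scatter-add (into rows)."""
    rows, cols = coo_idx[:, 0], coo_idx[:, 1]
    y = np.zeros(n_rows)
    np.add.at(y, rows, data * x[cols])
    return y

# CSR arrays for A = [[1, 0, 2], [0, 3, 0]]
indptr, indices, data = np.array([0, 2, 3]), np.array([0, 2, 1]), np.array([1.0, 2.0, 3.0])
vals, idx = csr_to_coo_sketch(indptr, indices, data)
y = sg_spmv_sketch(vals, idx, np.array([1.0, 2.0, 3.0]), n_rows=2)
```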
util
- to_np(x: Array) -> ndarray
- to_jax(x: ndarray)
Move a NumPy array to JAX via DLPack.
- tuple_meshgrid(tup)
Applies meshgrid to arrays in a named tuple.
- tuple_ravel(tup)
Flatten arrays in fields of tuple.
- tuple_shard(tup, n)
Shard arrays in fields of tuple.
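These helpers suit parameter sweeps stored as named tuples. Meshgrid turns one 1-D array per field into a full grid of combinations, and ravel flattens each field so that every grid point becomes one entry along a single batch axis. The NumPy sketch below uses a hypothetical two-field tuple (SweepSketch). The choice of indexing='ij' is an assumption.

```python
from typing import NamedTuple
import numpy as np

class SweepSketch(NamedTuple):          # hypothetical parameter tuple
    eta: np.ndarray
    J: np.ndarray

def tuple_meshgrid_sketch(tup):
    """Replace each field by its coordinate array on the joint grid."""
    return type(tup)(*np.meshgrid(*tup, indexing='ij'))

def tuple_ravel_sketch(tup):
    """Flatten every field so grid points line up along one axis."""
    return type(tup)(*(np.ravel(f) for f in tup))

grid = tuple_meshgrid_sketch(SweepSketch(eta=np.linspace(-5, -3, 3), J=np.array([10.0, 15.0])))
flat = tuple_ravel_sketch(grid)         # 6 (eta, J) combinations
```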
shtlc
SHT-based local coupling functions.
- make_lm(lmax: int)
- make_grid(nlat, nlon)
Create a grid for the SHT, with phi as latitude and theta as longitude.
- grid_pairwise_distance(theta, phi)
Compute pairwise distances on the grid. Memory-intensive for large grids.
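On a sphere, the natural pairwise distance is the great-circle (angular) distance. The NumPy sketch below computes it for every pair of grid points by converting latitude and longitude to unit vectors. The result has shape (N, N) for N grid points, which is why such a computation is memory-intensive. The sketch takes explicit lat and lon arguments in radians and returns distances on the unit sphere. pairwise_great_circle_sketch is a hypothetical name, and how grid_pairwise_distance orders or scales its output is not documented here.

```python
import numpy as np

def pairwise_great_circle_sketch(lat, lon):
    """Angular distances (radians) between all pairs of points on the unit sphere."""
    lat, lon = np.ravel(lat), np.ravel(lon)
    xyz = np.stack([np.cos(lat) * np.cos(lon),
                    np.cos(lat) * np.sin(lon),
                    np.sin(lat)], axis=1)        # (N, 3) unit vectors
    cosang = np.clip(xyz @ xyz.T, -1.0, 1.0)     # clip guards against rounding
    return np.arccos(cosang)                     # (N, N): O(N^2) memory

lat = np.array([np.pi / 2, 0.0, 0.0])    # north pole, two equator points
lon = np.array([0.0, 0.0, np.pi])
d = pairwise_great_circle_sketch(lat, lon)
```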
- randn_relaxed(sht)
For a shtns object sht, return a random spatial array that sht can represent.
- kernel_diff(D, l)
Compute the diffusion kernel, with l in shtns order.
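In the spherical-harmonic domain, diffusion is diagonal. On the unit sphere, every harmonic of degree l is an eigenfunction of the Laplacian with eigenvalue -l(l+1), so diffusing for a time t multiplies its coefficient by exp(-D l(l+1) t). The NumPy sketch below builds these per-coefficient factors for an array of degrees l, with the l values laid out as in shtns's m-major ordering for lmax = 2 (an assumption about the layout). The entry above does not say whether kernel_diff returns this exponential, the rate -D l(l+1), or a per-step variant, so treat the exact form as an assumption. diffusion_factors_sketch is a hypothetical name.

```python
import numpy as np

def diffusion_factors_sketch(D, l, t=1.0):
    """Heat-kernel multipliers exp(-D * l(l+1) * t) for SH coefficients of degree l."""
    l = np.asarray(l, dtype=float)
    return np.exp(-D * l * (l + 1.0) * t)

# l for lmax = 2, m-major: (l, m) = (0,0), (1,0), (2,0), (1,1), (2,1), (2,2)
l = np.array([0, 1, 2, 1, 2, 2])
f = diffusion_factors_sketch(D=0.1, l=l)
```

The l = 0 (mean) coefficient is left unchanged, while higher degrees, which carry finer spatial detail, decay faster.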
- kernel_dist_origin(theta, phi)
Compute the distance to the origin on the grid.
- kernel_sh_normalized(sht, k)
- kernel_laplace(sht, theta, phi, size)
Compute spatial & spectral coefficients for a Laplacian kernel.
- kernel_gaussian(sht, theta, phi, size)
Compute spatial & spectral coefficients for a Gaussian kernel.
- kernel_mexican_hat(sht, theta, phi, size)
Compute spatial & spectral coefficients for a Mexican hat kernel.
- kernel_conv_prep(sht, k)
Prepares an evaluated kernel for SHT convolution.
- kernel_estimate_shtns(sht, k, x0)
Estimate the effective kernel for kernel k with state x0 for the shtns object sht.
- make_shtdiff_np(lmax, nlat, nlon, D, return_L=False, np=numpy)
Construct an SHT-based diffusion implementation in plain NumPy.
- make_shtdiff(nlat, lmax=None, nlon=None, D=0.0004, return_L=False)
Construct an SHT-based diffusion implementation with JAX.