@@ -2,26 +2,128 @@
from typing import Union, Callable
from jax.lax import stop_gradient
import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import ZeroInit, OneInit, Initializer, init_param
from brainpy.initialize import (ZeroInit, OneInit, Initializer,
parameter, variable, noise as init_noise)
from brainpy.integrators import sdeint, odeint, JointEq
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Tensor
from brainpy.modes import Mode, NormalMode, BatchingMode, TrainingMode, normal, check
from brainpy.tools.checking import check_initializer, check_callable
from brainpy.types import Shape, Array
__all__ = [
'LeakyIntegrator',
'LIF',
'ExpIF',
'AdExIF',
'QuaIF',
'AdQuaIF',
'GIF',
'ALIFBellec2020',
'Izhikevich',
'HindmarshRose',
'FHN',
]
class LeakyIntegrator(NeuGroup):
r"""Leaky Integrator Model.
**Model Descriptions**
This class implements a leaky integrator model, in which its dynamics is
given by:
.. math::
\tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t)
where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting
membrane potential, :math:`\tau` is the time constant, and :math:`R` is the
resistance.
Parameters
----------
size: sequence of int, int
The size of the neuron group.
V_rest: float, JaxArray, ndarray, Initializer, callable
Resting membrane potential.
R: float, JaxArray, ndarray, Initializer, callable
Membrane resistance.
tau: float, JaxArray, ndarray, Initializer, callable
Membrane time constant.
V_initializer: JaxArray, ndarray, Initializer, callable
The initializer of membrane potential.
noise: JaxArray, ndarray, Initializer, callable
The noise added onto the membrane potential
method: str
The numerical integration method.
name: str
The group name.
"""
def __init__(
self,
# neuron group size
size: Shape,
keep_size: bool = False,
# neuron parameters
V_rest: Union[float, Array, Initializer, Callable] = 0.,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
# other parameter
name: str = None,
mode: Mode = normal,
method: str = 'exp_auto',
):
super(LeakyIntegrator, self).__init__(size=size,
mode=mode,
keep_size=keep_size,
name=name)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape)
# initializers
check_initializer(V_initializer, 'V_initializer')
self._V_initializer = V_initializer
# variables
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
# integral
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
def derivative(self, V, t, I_ext):
return (-V + self.V_rest + self.R * I_ext) / self.tau
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
def update(self, tdi, x=None):
if x is not None: self.input += x
self.V.value = self.integral(self.V.value, tdi.t, self.input.value, tdi.dt)
def clear_input(self):
self.input[:] = 0.
class LIF(NeuGroup):
r"""Leaky integrate-and-fire neuron model.
@@ -31,7 +133,7 @@ class LIF(NeuGroup):
.. math::
\tau \frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \\
\tau \frac{dV}{dt} = - (V(t) - V_{rest}) + R I(t) \\
\text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad
\text{last} \quad \tau_{ref} \quad \text{ms}
@@ -56,6 +158,8 @@ class LIF(NeuGroup):
Reset potential after spike.
V_th: float, JaxArray, ndarray, Initializer, callable
Threshold potential of spike.
R: float, JaxArray, ndarray, Initializer, callable
Membrane resistance.
tau: float, JaxArray, ndarray, Initializer, callable
Membrane time constant.
tau_ref: float, JaxArray, ndarray, Initializer, callable
@@ -64,8 +168,6 @@ class LIF(NeuGroup):
The initializer of membrane potential.
noise: JaxArray, ndarray, Initializer, callable
The noise added onto the membrane potential
noise_type: str
The type of the provided noise. Can be `value` or `func`.
method: str
The numerical integration method.
name: str
@@ -81,72 +183,118 @@ class LIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = 0.,
V_reset: Union[float, Tensor, Initializer, Callable] = -5.,
V_th: Union[float, Tensor, Initializer, Callable] = 20.,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
tau_ref: Union[float, Tensor, Initializer, Callable] = 1.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
noise: Union[float, Tensor, Initializer, Callable] = None,
noise_type: str = 'value',
keep_size: bool=False,
keep_size: bool = False,
# other parameter
V_rest: Union[float, Array, Initializer, Callable] = 0.,
V_reset: Union[float, Array, Initializer, Callable] = -5.,
V_th: Union[float, Array, Initializer, Callable] = 20.,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_ref: Union[float, Array, Initializer, Callable] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
name: str = None,
# training parameter
mode: Mode = normal,
spike_fun: Callable = bm.spike_with_sigmoid_grad,
):
# initialization
super(LIF, self).__init__(size=size, name=name)
super(LIF, self).__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
# parameters
self.keep_size = keep_size
self.noise_type = noise_type
if noise_type not in ['func', 'value']:
raise ValueError(f'noise_type only supports `func` and `value`, but we got {noise_type}')
size = self.size if keep_size else self.num
self.V_rest = init_param(V_rest, size, allow_none=False)
self.V_reset = init_param(V_reset, size, allow_none=False)
self.V_th = init_param(V_th, size, allow_none=False)
self.tau = init_param(tau, size, allow_none=False)
self.tau_ref = init_param(tau_ref, size, allow_none=False)
if noise_type == 'func':
self.noise = noise
else:
self.noise = init_param(noise, size, allow_none=True)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
self.noise = init_noise(noise, self.varshape)
self.spike_fun = check_callable(spike_fun, 'spike_fun')
# initializers
check_initializer(V_initializer, 'V_initializer')
self._V_initializer = V_initializer
# variables
self.V = bm.Variable(init_param(V_initializer, size))
self.input = bm.Variable(bm.zeros(size))
self.spike = bm.Variable(bm.zeros(size, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(size) * -1e7)
self.refractory = bm.Variable(bm.zeros(size, dtype=bool))
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(mode, TrainingMode) else bool # the gradient of spike is a float
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
if self.tau_ref is not None:
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
# integral
f = lambda V, t, I_ext: (-V + self.V_rest + I_ext) / self.tau
if self.noise is not None:
g = noise if (noise_type == 'func') else (lambda V, t, I_ext: self.noise / bm.sqrt(self.tau))
self.integral = sdeint(method=method, f=f, g=g)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = odeint(method=method, f=f)
def reset(self):
self.V.value = init_param(self._V_initializer, self.size if self.keep_size else self.num)
self.input[:] = 0
self.spike[:] = False
self.t_last_spike[:] = -1e7
self.refractory[:] = False
def update(self, t, dt):
refractory = (t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, t, self.input, dt=dt)
V = bm.where(refractory, self.V, V)
spike = V >= self.V_th
self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
self.spike.value = spike
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
def derivative(self, V, t, I_ext):
return (-V + self.V_rest + self.R * I_ext) / self.tau
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
if self.tau_ref is not None:
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
# integrate membrane potential
V = self.integral(self.V.value, t, self.input.value, dt)
if self.tau_ref is not None:
# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
if isinstance(self.mode, TrainingMode):
refractory = stop_gradient(refractory)
V = bm.where(refractory, self.V, V)
# spike, refractory, spiking time, and membrane potential reset
if isinstance(self.mode, TrainingMode):
spike = self.spike_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike)
V += (self.V_reset - V) * spike_no_grad
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
refractory = stop_gradient(bm.logical_or(refractory, spike_).value)
t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike).value)
else:
spike = V >= self.V_th
V = bm.where(spike, self.V_reset, V)
refractory = bm.logical_or(refractory, spike)
t_last_spike = bm.where(spike, t, self.t_last_spike)
self.V.value = V
self.spike.value = spike
self.refractory.value = refractory
self.t_last_spike.value = t_last_spike
else:
# spike, spiking time, and membrane potential reset
if isinstance(self.mode, TrainingMode):
spike = self.spike_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike)
V += (self.V_reset - V) * spike_no_grad
else:
spike = V >= self.V_th
V = bm.where(spike, self.V_reset, V)
self.V.value = V
self.spike.value = spike
def clear_input(self):
self.input[:] = 0.
@@ -251,66 +399,94 @@ class ExpIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = -65.,
V_reset: Union[float, Tensor, Initializer, Callable] = -68.,
V_th: Union[float, Tensor, Initializer, Callable] = -30.,
V_T: Union[float, Tensor, Initializer, Callable] = -59.9,
delta_T: Union[float, Tensor, Initializer, Callable] = 3.48,
R: Union[float, Tensor, Initializer, Callable] = 1.,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
tau_ref: Union[float, Tensor, Initializer, Callable] = 1.7,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
V_rest: Union[float, Array, Initializer, Callable] = -65.,
V_reset: Union[float, Array, Initializer, Callable] = -68.,
V_th: Union[float, Array, Initializer, Callable] = -30.,
V_T: Union[float, Array, Initializer, Callable] = -59.9,
delta_T: Union[float, Array, Initializer, Callable] = 3.48,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_ref: Union[float, Array, Initializer, Callable] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
keep_size: bool = False,
mode: Mode = normal,
method: str = 'exp_auto',
name: str = None
):
# initialize
super(ExpIF, self).__init__(size=size, name=name)
super(ExpIF, self).__init__(size=size,
name=name,
mode=mode,
keep_size=keep_size, )
check(self.mode, (TrainingMode, NormalMode), self.__class__)
# parameters
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.V_reset = init_param(V_reset, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_T = init_param(V_T, self.num, allow_none=False)
self.delta_T = init_param(delta_T, self.num, allow_none=False)
self.tau_ref = init_param(tau_ref, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.R = init_param(R, self.num, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_T = parameter(V_T, self.varshape, allow_none=False)
self.delta_T = parameter(delta_T, self.varshape, allow_none=False)
self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape)
# initializers
check_initializer(V_initializer, 'V_initializer')
self._V_initializer = V_initializer
# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
if self.tau_ref is not None:
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
self.t_last_spike[:] = -1e7
self.refractory[:] = False
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
if self.tau_ref is not None:
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
def derivative(self, V, t, I_ext):
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau
return dvdt
def update(self, t, dt):
refractory = (t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, t, self.input, dt=dt)
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
V = self.integral(self.V.value, t, self.input.value, dt)
if self.tau_ref is not None:
refractory = (t - self.t_last_spike) <= self.tau_ref
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
t_last_spike = bm.where(spike, t, self.t_last_spike)
V = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
else:
spike = self.V_th <= V
t_last_spike = bm.where(spike, t, self.t_last_spike)
V = bm.where(spike, self.V_reset, V)
self.V.value = V
self.spike.value = spike
self.t_last_spike.value = t_last_spike
def clear_input(self):
self.input[:] = 0.
@@ -390,34 +566,42 @@ class AdExIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = -65.,
V_reset: Union[float, Tensor, Initializer, Callable] = -68.,
V_th: Union[float, Tensor, Initializer, Callable] = -30.,
V_T: Union[float, Tensor, Initializer, Callable] = -59.9,
delta_T: Union[float, Tensor, Initializer, Callable] = 3.48,
a: Union[float, Tensor, Initializer, Callable] = 1.,
b: Union[float, Tensor, Initializer, Callable] = 1.,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
tau_w: Union[float, Tensor, Initializer, Callable] = 30.,
R: Union[float, Tensor, Initializer, Callable] = 1.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
V_rest: Union[float, Array, Initializer, Callable] = -65.,
V_reset: Union[float, Array, Initializer, Callable] = -68.,
V_th: Union[float, Array, Initializer, Callable] = -30.,
V_T: Union[float, Array, Initializer, Callable] = -59.9,
delta_T: Union[float, Array, Initializer, Callable] = 3.48,
a: Union[float, Array, Initializer, Callable] = 1.,
b: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_w: Union[float, Array, Initializer, Callable] = 30.,
R: Union[float, Array, Initializer, Callable] = 1.,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
keep_size: bool = False,
mode: Mode = normal,
name: str = None
):
super(AdExIF, self).__init__(size=size, name=name)
super(AdExIF, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode, )
check(self.mode, (TrainingMode, NormalMode), self.__class__)
# parameters
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.V_reset = init_param(V_reset, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_T = init_param(V_T, self.num, allow_none=False)
self.delta_T = init_param(delta_T, self.num, allow_none=False)
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.tau_w = init_param(tau_w, self.num, allow_none=False)
self.R = init_param(R, self.num, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_T = parameter(V_T, self.varshape, allow_none=False)
self.delta_T = parameter(delta_T, self.varshape, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.tau_w = parameter(tau_w, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=2)
# initializers
check_initializer(V_initializer, 'V_initializer')
@@ -426,25 +610,28 @@ class AdExIF(NeuGroup):
self._w_initializer = w_initializer
# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)) )
self.w = bm.Variable(init_param(w_initializer, (self.num,)) )
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool) )
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool) )
self.V = variable(V_initializer, mode, self.varshape )
self.w = variable(w_initializer, mode, self.varshape )
self.input = variable(bm.zeros, mode, self.varshape )
sp_type = bm.dftype() if isinstance(mode, BatchingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape )
# functions
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.w.value = init_param(self._w_initializer, (self.num,) )
self.input[:] = 0
self.spike[:] = False
self.refractory[:] = False
def reset_state (self, batch_size=None ):
self.V.value = variable(self._V_initializer, batch_size, self.varshape )
self.w.value = variable(self._w_initializer, batch_size, self.varshape )
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
def dV(self, V, t, w, I_ext):
dVdt = (- V + self.V_rest + self.delta_T * bm.exp((V - self.V_T) / self.delta_T) -
self.R * w + self.R * I_ext) / self.tau
exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I_ext) / self.tau
return dVdt
def dw(self, w, t, V):
@@ -455,12 +642,16 @@ class AdExIF(NeuGroup):
def derivative(self):
return JointEq([self.dV, self.dw])
def update(self, t, dt):
V, w = self.integral(self.V, self.w, t, self.input, dt=dt)
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt)
spike = V >= self.V_th
self.V.value = bm.where(spike, self.V_reset, V)
self.w.value = bm.where(spike, w + self.b, w)
self.spike.value = spike
def clear_input(self):
self.input[:] = 0.
@@ -534,65 +725,91 @@ class QuaIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = -65.,
V_reset: Union[float, Tensor, Initializer, Callable] = -68.,
V_th: Union[float, Tensor, Initializer, Callable] = -30.,
V_c: Union[float, Tensor, Initializer, Callable] = -50.0,
c: Union[float, Tensor, Initializer, Callable] = .07,
R: Union[float, Tensor, Initializer, Callable] = 1.,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
tau_ref: Union[float, Tensor, Initializer, Callable] = 0.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
V_rest: Union[float, Array, Initializer, Callable] = -65.,
V_reset: Union[float, Array, Initializer, Callable] = -68.,
V_th: Union[float, Array, Initializer, Callable] = -30.,
V_c: Union[float, Array, Initializer, Callable] = -50.0,
c: Union[float, Array, Initializer, Callable] = .07,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_ref: Union[float, Array, Initializer, Callable] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
keep_size: bool = False,
mode: Mode = normal,
method: str = 'exp_auto',
name: str = None
):
# initialization
super(QuaIF, self).__init__(size=size, name=name)
super(QuaIF, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
# parameters
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.V_reset = init_param(V_reset, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_c = init_param(V_c, self.num, allow_none=False)
self.c = init_param(c, self.num, allow_none=False)
self.R = init_param(R, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.tau_ref = init_param(tau_ref, self.num, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_c = parameter(V_c, self.varshape, allow_none=False)
self.c = parameter(c, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
self.noise = init_noise(noise, self.varshape, num_vars=1)
# initializers
check_initializer(V_initializer, '_V_initializer', allow_none=False)
self._V_initializer = V_initializer
# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
if self.tau_ref is not None:
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
self.t_last_spike[:] = -1e7
self.refractory[:] = False
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
if self.tau_ref is not None:
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
def derivative(self, V, t, I_ext):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau
return dVdt
def update(self, t, dt, **kwargs):
refractory = (t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, t, self.input, dt=dt)
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
V = self.integral(self.V.value, t, self.input.value, dt)
if self.tau_ref is not None:
refractory = (t - self.t_last_spike) <= self.tau_ref
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
t_last_spike = bm.where(spike, t, self.t_last_spike)
V = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
else:
spike = self.V_th <= V
t_last_spike = bm.where(spike, t, self.t_last_spike)
V = bm.where(spike, self.V_reset, V)
self.V.value = V
self.spike.value = spike
self.t_last_spike.value = t_last_spike
def clear_input(self):
self.input[:] = 0.
@@ -676,32 +893,40 @@ class AdQuaIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = -65.,
V_reset: Union[float, Tensor, Initializer, Callable] = -68.,
V_th: Union[float, Tensor, Initializer, Callable] = -30.,
V_c: Union[float, Tensor, Initializer, Callable] = -50.0,
a: Union[float, Tensor, Initializer, Callable] = 1.,
b: Union[float, Tensor, Initializer, Callable] = .1,
c: Union[float, Tensor, Initializer, Callable] = .07,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
tau_w: Union[float, Tensor, Initializer, Callable] = 10.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
V_rest: Union[float, Array, Initializer, Callable] = -65.,
V_reset: Union[float, Array, Initializer, Callable] = -68.,
V_th: Union[float, Array, Initializer, Callable] = -30.,
V_c: Union[float, Array, Initializer, Callable] = -50.0,
a: Union[float, Array, Initializer, Callable] = 1.,
b: Union[float, Array, Initializer, Callable] = .1,
c: Union[float, Array, Initializer, Callable] = .07,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_w: Union[float, Array, Initializer, Callable] = 10.,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
keep_size: bool = False,
mode: Mode = normal,
name: str = None
):
super(AdQuaIF, self).__init__(size=size, name=name)
super(AdQuaIF, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode, )
check(self.mode, (TrainingMode, NormalMode), self.__class__)
# parameters
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.V_reset = init_param(V_reset, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_c = init_param(V_c, self.num, allow_none=False)
self.c = init_param(c, self.num, allow_none=False)
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.tau_w = init_param(tau_w, self.num, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_c = parameter(V_c, self.varshape, allow_none=False)
self.c = parameter(c, self.varshape, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.tau_w = parameter(tau_w, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=2)
# initializers
check_initializer(V_initializer, 'V_initializer', allow_none=False)
@@ -710,21 +935,26 @@ class AdQuaIF(NeuGroup):
self._w_initializer = w_initializer
# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = variable(V_initializer, mode, self.varshape)
self.w = variable(w_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.w.value = init_param(self._w_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
self.refractory[:] = False
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.w.value = variable(self._w_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
def dV(self, V, t, w, I_ext):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau
@@ -738,12 +968,16 @@ class AdQuaIF(NeuGroup):
def derivative(self):
  """Joint ODE system ``[dV/dt, dw/dt]`` of the model.

  Wrapped in ``JointEq`` so both equations are integrated together by the
  ``odeint``/``sdeint`` integrator built in ``__init__``.  (Presumably
  decorated with ``@property`` above this view — confirm in the class body.)
  """
  return JointEq([self.dV, self.dw])
def update(self, tdi, x=None):
  """Advance the neuron group by one time step.

  NOTE(review): a dead legacy ``update(self, t, dt)`` definition left over
  from an incomplete merge preceded this one (it was silently overridden
  by this def); it has been removed.

  Parameters
  ----------
  tdi: shared step info; only ``tdi.t`` (time) and ``tdi.dt`` (step) are used.
  x: optional external current, accumulated onto ``self.input`` first.
  """
  t, dt = tdi.t, tdi.dt
  if x is not None:
    self.input += x
  # integrate membrane potential V and adaptation current w together
  V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt)
  # hard-threshold spike detection, then reset V and strengthen adaptation
  spike = self.V_th <= V
  self.V.value = bm.where(spike, self.V_reset, V)
  self.w.value = bm.where(spike, w + self.b, w)
  self.spike.value = spike
def clear_input(self):
  # Zero the accumulated external input; called once per simulation step
  # after all projections have delivered their currents.
  self.input[:] = 0.
@@ -832,45 +1066,57 @@ class GIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = -70.,
V_reset: Union[float, Tensor, Initializer, Callable] = -70.,
V_th_inf: Union[float, Tensor, Initializer, Callable] = -50.,
V_th_reset: Union[float, Tensor, Initializer, Callable] = -60.,
R: Union[float, Tensor, Initializer, Callable] = 20.,
tau: Union[float, Tensor, Initializer, Callable] = 20.,
a: Union[float, Tensor, Initializer, Callable] = 0.,
b: Union[float, Tensor, Initializer, Callable] = 0.01,
k1: Union[float, Tensor, Initializer, Callable] = 0.2,
k2: Union[float, Tensor, Initializer, Callable] = 0.02,
R1: Union[float, Tensor, Initializer, Callable] = 0.,
R2: Union[float, Tensor, Initializer, Callable] = 1.,
A1: Union[float, Tensor, Initializer, Callable] = 0.,
A2: Union[float, Tensor, Initializer, Callable] = 0.,
V_initializer: Union[Initializer, Callable, Tensor] = OneInit(-70.),
I1_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
I2_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
Vth_initializer: Union[Initializer, Callable, Tensor] = OneInit(-50.),
V_rest: Union[float, Array, Initializer, Callable] = -70.,
V_reset: Union[float, Array, Initializer, Callable] = -70.,
V_th_inf: Union[float, Array, Initializer, Callable] = -50.,
V_th_reset: Union[float, Array, Initializer, Callable] = -60.,
R: Union[float, Array, Initializer, Callable] = 20.,
tau: Union[float, Array, Initializer, Callable] = 20.,
a: Union[float, Array, Initializer, Callable] = 0.,
b: Union[float, Array, Initializer, Callable] = 0.01,
k1: Union[float, Array, Initializer, Callable] = 0.2,
k2: Union[float, Array, Initializer, Callable] = 0.02,
R1: Union[float, Array, Initializer, Callable] = 0.,
R2: Union[float, Array, Initializer, Callable] = 1.,
A1: Union[float, Array, Initializer, Callable] = 0.,
A2: Union[float, Array, Initializer, Callable] = 0.,
V_initializer: Union[Initializer, Callable, Array] = OneInit(-70.),
I1_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
I2_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
Vth_initializer: Union[Initializer, Callable, Array] = OneInit(-50.),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
keep_size: bool = False,
name: str = None,
# parameter for training
mode: Mode = normal,
spike_fun: Callable = bm.spike_with_sigmoid_grad,
):
# initialization
super(GIF, self).__init__(size=size, name=name)
super(GIF, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
# params
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.V_reset = init_param(V_reset, self.num, allow_none=False)
self.V_th_inf = init_param(V_th_inf, self.num, allow_none=False)
self.V_th_reset = init_param(V_th_reset, self.num, allow_none=False)
self.R = init_param(R, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.k1 = init_param(k1, self.num, allow_none=False)
self.k2 = init_param(k2, self.num, allow_none=False)
self.R1 = init_param(R1, self.num, allow_none=False)
self.R2 = init_param(R2, self.num, allow_none=False)
self.A1 = init_param(A1, self.num, allow_none=False)
self.A2 = init_param(A2, self.num, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th_inf = parameter(V_th_inf, self.varshape, allow_none=False)
self.V_th_reset = parameter(V_th_reset, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.k1 = parameter(k1, self.varshape, allow_none=False)
self.k2 = parameter(k2, self.varshape, allow_none=False)
self.R1 = parameter(R1, self.varshape, allow_none=False)
self.R2 = parameter(R2, self.varshape, allow_none=False)
self.A1 = parameter(A1, self.varshape, allow_none=False)
self.A2 = parameter(A2, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=4)
self.spike_fun = check_callable(spike_fun, 'spike_fun')
# initializers
check_initializer(V_initializer, 'V_initializer')
@@ -883,23 +1129,28 @@ class GIF(NeuGroup):
self._Vth_initializer = Vth_initializer
# variables
self.I1 = bm.Variable(init_param(I1_initializer, (self.num,)))
self.I2 = bm.Variable(init_param(I2_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.V_th = bm.Variable(init_param(Vth_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.I1 = variable(I1_initializer, mode, self.varshape)
self.I2 = variable(I2_initializer, mode, self.varshape)
self.V_th = variable(Vth_initializer, mode, self.varshape)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
def reset(self):
  """Reset all state variables to their initial values (legacy API).

  NOTE(review): the previous body still called the removed ``init_param``
  helper with the legacy ``(self.num,)`` shape, raising ``NameError`` under
  the new initializer API.  Delegate to :meth:`reset_state`, which
  re-initializes the same variables (V, I1, I2, V_th, input, spike).
  """
  self.reset_state(batch_size=None)
def reset_state(self, batch_size=None):
  """Re-initialize I1, I2, V_th, V, input and spike, optionally batched."""
  shape = self.varshape
  self.I1.value = variable(self._I1_initializer, batch_size, shape)
  self.I2.value = variable(self._I2_initializer, batch_size, shape)
  self.V_th.value = variable(self._Vth_initializer, batch_size, shape)
  self.V.value = variable(self._V_initializer, batch_size, shape)
  self.input.value = variable(bm.zeros, batch_size, shape)
  # spikes are floats in training mode (surrogate gradients), bools otherwise
  spike_dtype = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
  self.spike.value = variable(lambda s: bm.zeros(s, dtype=spike_dtype), batch_size, shape)
def dI1(self, I1, t):
  """Derivative of the internal current ``I1``: exponential decay at rate ``k1``."""
  return - self.k1 * I1
@@ -917,19 +1168,199 @@ class GIF(NeuGroup):
def derivative(self):
  """Joint ODE system ``[dI1, dI2, dVth, dV]`` integrated as one step.

  (Presumably decorated with ``@property`` just above this view — confirm.)
  """
  return JointEq([self.dI1, self.dI2, self.dVth, self.dV])
def update(self, tdi, x=None):
  """Advance the GIF neuron group one time step.

  Integrates the four coupled variables (I1, I2, V_th, V), then applies
  spike detection and the after-spike resets of the GIF model.  In
  training mode a surrogate-gradient spike function replaces the hard
  threshold so the step stays differentiable.

  NOTE(review): two merge leftovers were removed here — an orphan legacy
  ``def update(self, t, dt):`` header, and a duplicated hard-threshold
  reset block that ran BEFORE the mode branch and therefore corrupted
  I1/I2/V_th/V (especially in training mode, where the surrogate spike
  was then computed on already-reset values).
  """
  t, dt = tdi.t, tdi.dt
  # accumulate external input before integration
  if x is not None: self.input += x
  I1, I2, V_th, V = self.integral(self.I1, self.I2, self.V_th, self.V, t, self.input, dt=dt)
  # spike and resets
  if isinstance(self.mode, TrainingMode):
    # surrogate spike in [0, 1]; resets are applied proportionally
    spike = self.spike_fun(V - self.V_th)
    V += (self.V_reset - V) * spike
    I1 += spike * (self.R1 * I1 + self.A1 - I1)
    I2 += spike * (self.R2 * I2 + self.A2 - I2)
    reset_th = self.spike_fun(self.V_th_reset - V_th) * spike
    V_th += reset_th * (self.V_th_reset - V_th)
  else:
    spike = self.V_th <= V
    V = bm.where(spike, self.V_reset, V)
    I1 = bm.where(spike, self.R1 * I1 + self.A1, I1)
    I2 = bm.where(spike, self.R2 * I2 + self.A2, I2)
    # the threshold is only pulled up to V_th_reset when it fell below it
    reset_th = bm.logical_and(V_th < self.V_th_reset, spike)
    V_th = bm.where(reset_th, self.V_th_reset, V_th)
  # commit new state
  self.spike.value = spike
  self.I1.value = I1
  self.I2.value = I2
  self.V_th.value = V_th
  self.V.value = V
def clear_input(self):
  # Zero the accumulated external input; called once per simulation step
  # after all projections have delivered their currents.
  self.input[:] = 0.
class ALIFBellec2020(NeuGroup):
  r"""Leaky Integrate-and-Fire model with SFA [1]_.

  This model is similar to the GLIF2 model in the Technical White Paper
  on generalized LIF (GLIF) models from AllenInstitute [2]_.

  Formally, this model is given by:

  .. math::

     \tau \dot{V} = -(V - V_{\mathrm{rest}}) + R*I \\
     \tau_a \dot{a} = -a

  Once a spike is induced by :math:`V(t) > V_{\mathrm{th}} + \beta a`, then

  .. math::

     V \gets V - V_{\mathrm{th}} \\
     a \gets a + 1

  References
  ----------
  .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for
         recurrent networks of spiking neurons."
         Nature communications 11.1 (2020): 1-15.
  .. [2] Allen Institute: Cell Types Database. © 2018 Allen Institute for
         Brain Science. Allen Cell Types Database, cell feature search.
         Available from: celltypes.brain-map.org/data (2018).
  """

  def __init__(
      self,
      size: Shape,
      keep_size: bool = False,
      # model parameters
      V_rest: Union[float, Array, Initializer, Callable] = -70.,
      V_th: Union[float, Array, Initializer, Callable] = -60.,
      R: Union[float, Array, Initializer, Callable] = 1.,
      beta: Union[float, Array, Initializer, Callable] = 1.6,
      tau: Union[float, Array, Initializer, Callable] = 20.,
      tau_a: Union[float, Array, Initializer, Callable] = 2000.,
      # None disables the refractory-period bookkeeping entirely
      tau_ref: Union[float, Array, Initializer, Callable] = None,
      noise: Union[float, Array, Initializer, Callable] = None,
      # initializers
      V_initializer: Union[Initializer, Callable, Array] = OneInit(-70.),
      # NOTE(review): default starts the adaptation variable at -50 — confirm intended
      a_initializer: Union[Initializer, Callable, Array] = OneInit(-50.),
      # parameter for training
      spike_fun: Callable = bm.spike_with_linear_grad,
      # other parameters
      method: str = 'exp_auto',
      name: str = None,
      mode: Mode = normal,
      # if True, stop the gradient through the reset term (e-prop style learning)
      eprop: bool = False
  ):
    super(ALIFBellec2020, self).__init__(name=name,
                                         size=size,
                                         keep_size=keep_size,
                                         mode=mode)
    check(self.mode, (TrainingMode, NormalMode), self.__class__)

    # parameters (each is broadcast/initialized to the group's variable shape)
    self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
    self.V_th = parameter(V_th, self.varshape, allow_none=False)
    self.R = parameter(R, self.varshape, allow_none=False)
    self.beta = parameter(beta, self.varshape, allow_none=False)
    self.tau = parameter(tau, self.varshape, allow_none=False)
    self.tau_a = parameter(tau_a, self.varshape, allow_none=False)
    self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
    self.noise = init_noise(noise, self.varshape, num_vars=2)  # two equations: V and a
    self.spike_fun = check_callable(spike_fun, 'spike_fun')
    self.eprop = eprop

    # initializers (kept for reset_state)
    check_initializer(V_initializer, 'V_initializer')
    check_initializer(a_initializer, 'a_initializer')
    self._V_initializer = V_initializer
    self._a_initializer = a_initializer

    # variables
    self.a = variable(a_initializer, mode, self.varshape)
    self.V = variable(V_initializer, mode, self.varshape)
    self.input = variable(bm.zeros, mode, self.varshape)
    # spikes are floats in training mode (surrogate gradients), bools otherwise
    sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
    self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
    if self.tau_ref is not None:
      # refractory bookkeeping only exists when a refractory period is configured
      self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
      self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

    # integral: ODE solver, or SDE solver when noise is given
    if self.noise is None:
      self.integral = odeint(method=method, f=self.derivative)
    else:
      self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

  def da(self, a, t):
    """Adaptation variable: exponential decay with time constant ``tau_a``."""
    return -a / self.tau_a

  def dV(self, V, t, I_ext):
    """Membrane potential: leaky integration of the external current."""
    return (- (V - self.V_rest) + self.R * I_ext) / self.tau

  @property
  def derivative(self):
    """Joint ODE system ``[dV/dt, da/dt]``."""
    return JointEq([self.dV, self.da])

  def reset_state(self, batch_size=None):
    """Re-initialize all state variables, optionally with a batch axis."""
    self.a.value = variable(self._a_initializer, batch_size, self.varshape)
    self.V.value = variable(self._V_initializer, batch_size, self.varshape)
    self.input.value = variable(bm.zeros, batch_size, self.varshape)
    sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
    self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
    if self.tau_ref is not None:
      self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
      self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

  def update(self, tdi, x=None):
    """Advance the group one step: integrate V and a, apply refractoriness
    (when configured), detect spikes, and apply the soft reset
    ``V <- V - V_th`` and ``a <- a + 1``."""
    t, dt = tdi.t, tdi.dt
    # integral
    if x is not None: self.input += x
    V, a = self.integral(self.V, self.a, t, self.input, dt)

    if self.tau_ref is not None:
      # refractory: hold V at its previous value inside the refractory window
      refractory = (t - self.t_last_spike) <= self.tau_ref
      if isinstance(self.mode, TrainingMode):
        # the refractory mask is bookkeeping, not a differentiable path
        refractory = stop_gradient(refractory)
      V = bm.where(refractory, self.V, V)

      # spike and reset
      if isinstance(self.mode, TrainingMode):
        # surrogate spike; the effective threshold is raised by beta * a
        spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th)
        # with eprop, the reset term carries no gradient
        V -= self.V_th * (stop_gradient(spike) if self.eprop else spike)
        # will be used in other place, like Delta Synapse, so stop its gradient
        spike_ = spike > 0.
        refractory = stop_gradient(bm.logical_or(refractory, spike_).value)
        t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike).value)
      else:
        spike = V >= (self.V_th + self.beta * self.a)
        refractory = bm.logical_or(refractory, spike)
        t_last_spike = bm.where(spike, t, self.t_last_spike)
        V -= self.V_th * spike
      self.refractory.value = refractory
      self.t_last_spike.value = t_last_spike

    else:
      # spike and reset (no refractory bookkeeping allocated)
      if isinstance(self.mode, TrainingMode):
        spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th)
        V -= self.V_th * (stop_gradient(spike) if self.eprop else spike)
      else:
        spike = V >= (self.V_th + self.beta * self.a)
        V -= self.V_th * spike

    # commit new state; each spike increments the adaptation variable
    self.spike.value = spike
    self.V.value = V
    self.a.value = a + spike

  def clear_input(self):
    # zero the accumulated external input for the next step
    self.input[:] = 0.
@@ -1004,27 +1435,37 @@ class Izhikevich(NeuGroup):
def __init__(
self,
size: Shape,
a: Union[float, Tensor, Initializer, Callable] = 0.02,
b: Union[float, Tensor, Initializer, Callable] = 0.20,
c: Union[float, Tensor, Initializer, Callable] = -65.,
d: Union[float, Tensor, Initializer, Callable] = 8.,
V_th: Union[float, Tensor, Initializer, Callable] = 30.,
tau_ref: Union[float, Tensor, Initializer, Callable] = 0.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
u_initializer: Union[Initializer, Callable, Tensor] = OneInit(),
a: Union[float, Array, Initializer, Callable] = 0.02,
b: Union[float, Array, Initializer, Callable] = 0.20,
c: Union[float, Array, Initializer, Callable] = -65.,
d: Union[float, Array, Initializer, Callable] = 8.,
V_th: Union[float, Array, Initializer, Callable] = 30.,
tau_ref: Union[float, Array, Initializer, Callable] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
u_initializer: Union[Initializer, Callable, Array] = OneInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
mode: Mode = normal,
spike_fun: Callable = bm.spike_with_sigmoid_grad,
keep_size: bool = False,
name: str = None
):
# initialization
super(Izhikevich, self).__init__(size=size, name=name)
super(Izhikevich, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
# params
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.c = init_param(c, self.num, allow_none=False)
self.d = init_param(d, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.tau_ref = init_param(tau_ref, self.num, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.c = parameter(c, self.varshape, allow_none=False)
self.d = parameter(d, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
self.noise = init_noise(noise, self.varshape, num_vars=2)
self.spike_fun = check_callable(spike_fun, 'spike_fun')
# initializers
check_initializer(V_initializer, 'V_initializer', allow_none=False)
@@ -1033,23 +1474,30 @@ class Izhikevich(NeuGroup):
self._u_initializer = u_initializer
# variables
self.u = bm.Variable(init_param(u_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.u = variable(u_initializer, mode, self.varshape)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
if self.tau_ref is not None:
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
# functions
self.integral = odeint(method=method, f=JointEq([self.dV, self.du]))
def reset(self):
  """Reset all state variables to their initial values (legacy API).

  NOTE(review): the previous body still called the removed ``init_param``
  helper and unconditionally touched ``t_last_spike``/``refractory``,
  which are only allocated when ``tau_ref`` is set.  Delegate to
  :meth:`reset_state`, which handles both cases through the current API.
  """
  self.reset_state(batch_size=None)
if self.noise is None:
self.integral = odeint(method=method, f=JointEq([self.dV, self.du]))
else:
self.integral = sdeint(method=method, f=JointEq([self.dV, self.du]), g=self.noise)
def reset_state(self, batch_size=None):
  """Re-initialize V, u, input and spike (plus the refractory bookkeeping
  when a refractory period is configured), optionally with a batch axis."""
  shape = self.varshape
  self.V.value = variable(self._V_initializer, batch_size, shape)
  self.u.value = variable(self._u_initializer, batch_size, shape)
  self.input.value = variable(bm.zeros, batch_size, shape)
  # spikes are floats in training mode (surrogate gradients), bools otherwise
  spike_dtype = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
  self.spike.value = variable(lambda s: bm.zeros(s, dtype=spike_dtype), batch_size, shape)
  if self.tau_ref is not None:
    self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, shape)
    self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, shape)
def dV(self, V, t, u, I_ext):
dVdt = 0.04 * V * V + 5 * V + 140 - u + I_ext
@@ -1059,16 +1507,55 @@ class Izhikevich(NeuGroup):
dudt = self.a * (self.b * V - u)
return dudt
def update(self, tdi, x=None):
  """Advance the Izhikevich group one time step.

  NOTE(review): a dead legacy ``update(self, t, dt)`` definition left by an
  incomplete merge preceded this one (it was silently overridden by this
  def, and it dereferenced ``self.t_last_spike`` unconditionally, which
  only exists when ``tau_ref`` is set); it has been removed.
  """
  t, dt = tdi.t, tdi.dt
  # integrate membrane potential and recovery variable
  if x is not None: self.input += x
  V, u = self.integral(self.V, self.u, t, self.input, dt)

  if self.tau_ref is not None:
    # hold V at its previous value inside the refractory window
    refractory = (t - self.t_last_spike) <= self.tau_ref
    if isinstance(self.mode, TrainingMode):
      # the refractory mask is bookkeeping, not a differentiable path
      refractory = stop_gradient(refractory)
    V = bm.where(refractory, self.V, V)

    # spike, refractory, and reset membrane potential
    if isinstance(self.mode, TrainingMode):
      spike = self.spike_fun(V - self.V_th)
      spike_no_grad = stop_gradient(spike)
      V += spike_no_grad * (self.c - self.V_th)
      u += spike_no_grad * self.d
      spike_ = spike_no_grad > 0.
      refractory = stop_gradient(bm.logical_or(refractory, spike_).value)
      t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike).value)
    else:
      spike = self.V_th <= V
      V = bm.where(spike, self.c, V)
      u = bm.where(spike, u + self.d, u)
      refractory = bm.logical_or(refractory, spike)
      t_last_spike = bm.where(spike, t, self.t_last_spike)
    self.refractory.value = refractory
    self.t_last_spike.value = t_last_spike

  else:
    # spike and reset membrane potential (no refractory bookkeeping)
    if isinstance(self.mode, TrainingMode):
      spike = self.spike_fun(V - self.V_th)
      spike_no_grad = stop_gradient(spike)
      V += spike_no_grad * (self.c - self.V_th)
      u += spike_no_grad * self.d
    else:
      spike = self.V_th <= V
      V = bm.where(spike, self.c, V)
      u = bm.where(spike, u + self.d, u)

  # commit new state
  self.V.value = V
  self.u.value = u
  self.spike.value = spike
def clear_input(self):
  # Zero the accumulated external input; called once per simulation step
  # after all projections have delivered their currents.
  self.input[:] = 0.
@@ -1173,32 +1660,44 @@ class HindmarshRose(NeuGroup):
def __init__(
self,
size: Shape,
a: Union[float, Tensor, Initializer, Callable] = 1.,
b: Union[float, Tensor, Initializer, Callable] = 3.,
c: Union[float, Tensor, Initializer, Callable] = 1.,
d: Union[float, Tensor, Initializer, Callable] = 5.,
r: Union[float, Tensor, Initializer, Callable] = 0.01,
s: Union[float, Tensor, Initializer, Callable] = 4.,
V_rest: Union[float, Tensor, Initializer, Callable] = -1.6,
V_th: Union[float, Tensor, Initializer, Callable] = 1.0,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
y_initializer: Union[Initializer, Callable, Tensor] = OneInit(-10.),
z_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
a: Union[float, Array, Initializer, Callable] = 1.,
b: Union[float, Array, Initializer, Callable] = 3.,
c: Union[float, Array, Initializer, Callable] = 1.,
d: Union[float, Array, Initializer, Callable] = 5.,
r: Union[float, Array, Initializer, Callable] = 0.01,
s: Union[float, Array, Initializer, Callable] = 4.,
V_rest: Union[float, Array, Initializer, Callable] = -1.6,
V_th: Union[float, Array, Initializer, Callable] = 1.0,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
y_initializer: Union[Initializer, Callable, Array] = OneInit(-10.),
z_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
keep_size: bool = False,
name: str = None,
# parameters for training
mode: Mode = normal,
spike_fun: Callable = bm.spike2_with_sigmoid_grad,
):
# initialization
super(HindmarshRose, self).__init__(size=size, name=name)
super(HindmarshRose, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
# parameters
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.c = init_param(c, self.num, allow_none=False)
self.d = init_param(d, self.num, allow_none=False)
self.r = init_param(r, self.num, allow_none=False)
self.s = init_param(s, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.c = parameter(c, self.varshape, allow_none=False)
self.d = parameter(d, self.varshape, allow_none=False)
self.r = parameter(r, self.varshape, allow_none=False)
self.s = parameter(s, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=3)
self.spike_fun = check_callable(spike_fun, 'spike_fun')
# variables
check_initializer(V_initializer, 'V_initializer', allow_none=False)
@@ -1209,21 +1708,26 @@ class HindmarshRose(NeuGroup):
self._z_initializer = z_initializer
# variables
self.z = bm.Variable(init_param(V_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.V = bm.Variable(init_param(z_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = variable(self._V_initializer, mode, self.varshape)
self.y = variable(self._y_initializer, mode, self.varshape)
self.z = variable(self._z_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
def reset(self):
  """Reset all state variables to their initial values (legacy API).

  NOTE(review): the previous body still called the removed ``init_param``
  helper with the legacy ``(self.num,)`` shape, raising ``NameError``
  under the new initializer API.  Delegate to :meth:`reset_state`, which
  re-initializes the same variables (V, y, z, input, spike).
  """
  self.reset_state(batch_size=None)
def reset_state(self, batch_size=None):
  """Re-initialize V, y, z, input and spike, optionally with a batch axis."""
  shape = self.varshape
  self.V.value = variable(self._V_initializer, batch_size, shape)
  self.y.value = variable(self._y_initializer, batch_size, shape)
  self.z.value = variable(self._z_initializer, batch_size, shape)
  self.input.value = variable(bm.zeros, batch_size, shape)
  # spikes are floats in training mode (surrogate gradients), bools otherwise
  spike_dtype = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
  self.spike.value = variable(lambda s: bm.zeros(s, dtype=spike_dtype), batch_size, shape)
def dV(self, V, t, y, z, I_ext):
  """Fast membrane dynamics: cubic nonlinearity driven by the spiking
  variable ``y``, the slow adaptation ``z``, and the external input."""
  return y - self.a * V * V * V + self.b * V * V - z + I_ext
@@ -1238,12 +1742,19 @@ class HindmarshRose(NeuGroup):
def derivative(self):
  """Joint ODE system ``[dV, dy, dz]`` integrated as one step.

  (Presumably decorated with ``@property`` just above this view — confirm.)
  """
  return JointEq([self.dV, self.dy, self.dz])
def update(self, tdi, x=None):
  """Advance the Hindmarsh–Rose group one time step.

  NOTE(review): an orphan legacy ``def update(self, t, dt):`` header left
  by an incomplete merge preceded this definition; it has been removed.
  """
  t, dt = tdi.t, tdi.dt
  if x is not None: self.input += x
  V, y, z = self.integral(self.V, self.y, self.z, t, self.input, dt=dt)
  # a "spike" is an upward crossing of V_th within this step
  if isinstance(self.mode, TrainingMode):
    self.spike.value = self.spike_fun(V - self.V_th, self.V - self.V_th)
  else:
    self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
  # commit new state
  self.V.value = V
  self.y.value = y
  self.z.value = z
def clear_input(self):
  # Zero the accumulated external input; called once per simulation step
  # after all projections have delivered their currents.
  self.input[:] = 0.
@@ -1333,23 +1844,35 @@ class FHN(NeuGroup):
def __init__(
self,
size: Shape,
a: Union[float, Tensor, Initializer, Callable] = 0.7,
b: Union[float, Tensor, Initializer, Callable] = 0.8,
tau: Union[float, Tensor, Initializer, Callable] = 12.5,
Vth: Union[float, Tensor, Initializer, Callable] = 1.8,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
a: Union[float, Array, Initializer, Callable] = 0.7,
b: Union[float, Array, Initializer, Callable] = 0.8,
tau: Union[float, Array, Initializer, Callable] = 12.5,
Vth: Union[float, Array, Initializer, Callable] = 1.8,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
keep_size: bool = False,
name: str = None,
# parameters for training
mode: Mode = normal,
spike_fun: Callable = bm.spike2_with_sigmoid_grad,
):
# initialization
super(FHN, self).__init__(size=size, name=name)
super(FHN, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
# parameters
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.Vth = init_param(Vth, self.num, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.Vth = parameter(Vth, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=2)
self.spike_fun = check_callable(spike_fun, 'spike_fun')
# initializers
check_initializer(V_initializer, 'V_initializer')
@@ -1358,19 +1881,24 @@ class FHN(NeuGroup):
self._w_initializer = w_initializer
# variables
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = variable(self._V_initializer, mode, self.varshape)
self.w = variable(self._w_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
def reset(self):
  """Reset all state variables to their initial values (legacy API).

  NOTE(review): the previous body still called the removed ``init_param``
  helper with the legacy ``(self.num,)`` shape, raising ``NameError``
  under the new initializer API.  Delegate to :meth:`reset_state`, which
  re-initializes the same variables (V, w, input, spike).
  """
  self.reset_state(batch_size=None)
def reset_state(self, batch_size=None):
  """Re-initialize V, w, input and spike, optionally with a batch axis."""
  shape = self.varshape
  self.V.value = variable(self._V_initializer, batch_size, shape)
  self.w.value = variable(self._w_initializer, batch_size, shape)
  self.input.value = variable(bm.zeros, batch_size, shape)
  # spikes are floats in training mode (surrogate gradients), bools otherwise
  spike_dtype = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
  self.spike.value = variable(lambda s: bm.zeros(s, dtype=spike_dtype), batch_size, shape)
def dV(self, V, t, w, I_ext):
  """Fast (membrane) variable: cubic FitzHugh–Nagumo nonlinearity."""
  return V - V * V * V / 3 - w + I_ext
@@ -1382,9 +1910,16 @@ class FHN(NeuGroup):
def derivative(self):
  """Joint ODE system ``[dV, dw]`` integrated as one step.

  (Presumably decorated with ``@property`` just above this view — confirm.)
  """
  return JointEq([self.dV, self.dw])
def update(self, tdi, x=None):
  """Advance the FitzHugh–Nagumo group one time step.

  NOTE(review): a dead legacy ``update(self, t, dt)`` definition left by
  an incomplete merge preceded this one (it was silently overridden by
  this def and never committed V/w back); it has been removed.
  """
  t, dt = tdi.t, tdi.dt
  if x is not None: self.input += x
  V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt=dt)
  # a "spike" is an upward crossing of Vth within this step
  if isinstance(self.mode, TrainingMode):
    self.spike.value = self.spike_fun(V - self.Vth, self.V - self.Vth)
  else:
    self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth)
  # commit new state
  self.V.value = V
  self.w.value = w
def clear_input(self):
  # Zero the accumulated external input; called once per simulation step
  # after all projections have delivered their currents.
  self.input[:] = 0.