Browse Source

update synapses (#167)

update synapses
pull/30/head
Chaoming Wang GitHub 1 year ago
parent
commit
d0aee2c223
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 168 additions and 251 deletions
  1. +2
    -1
      .github/workflows/Windows_CI.yml
  2. +40
    -0
      brainpy/dyn/base.py
  3. +108
    -226
      brainpy/dyn/synapses/abstract_models.py
  4. +12
    -13
      brainpy/dyn/synapses/biological_models.py
  5. +2
    -7
      brainpy/dyn/synapses/delay_coupling.py
  6. +2
    -2
      brainpy/math/delayvars.py
  7. +2
    -2
      requirements-win.txt

+ 2
- 1
.github/workflows/Windows_CI.yml View File

@@ -28,7 +28,8 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install jax[cpu] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install numpy==1.21.0
python -m pip install "jax[cpu]==0.3.5" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install -r requirements-win.txt
python -m pip install tqdm brainpylib
python setup.py install


+ 40
- 0
brainpy/dyn/base.py View File

@@ -12,6 +12,7 @@ from brainpy import tools
from brainpy.base.base import Base
from brainpy.base.collector import Collector
from brainpy.connect import TwoEndConnector, MatConn, IJConn
from brainpy.initialize import Initializer, ZeroInit
from brainpy.errors import ModelBuildError
from brainpy.integrators.base import Integrator
from brainpy.types import Tensor
@@ -337,6 +338,10 @@ class TwoEndConn(DynamicalSystem):
The name of the dynamic system.
"""

"""Global delay variables. Useful when the same target
variable is used in multiple mappings."""
global_delay_vars: Dict[str, bm.LengthDelay] = dict()

def __init__(
self,
pre: NeuGroup,
@@ -344,6 +349,9 @@ class TwoEndConn(DynamicalSystem):
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]] = None,
name: str = None
):
# local delay variables
self.local_delay_vars: Dict[str, bm.LengthDelay] = dict()

# pre or post neuron group
# ------------------------
if not isinstance(pre, NeuGroup):
@@ -399,3 +407,35 @@ class TwoEndConn(DynamicalSystem):
raise ValueError(f'Must be string. But got {attr}.')
if not hasattr(self.post, attr):
raise ModelBuildError(f'{self} need "pre" neuron group has attribute "{attr}".')

def register_delay(
self,
name: str,
num_delay: int,
variable: Union[bm.JaxArray, jnp.ndarray],
delay_initializer: Initializer = ZeroInit(),
domain: str = 'global'
):
if domain not in ['global', 'local']:
raise ValueError('"domain" must be a string in ["global", "local"]. '
f'Bug we got {domain}.')
num_delay = int(num_delay)
delay_data = delay_initializer((num_delay,) + variable.shape, dtype=variable.dtype)

if domain == 'local':
self.local_delay_vars[name] = bm.LengthDelay(variable, num_delay, delay_data)
else:
if name not in self.global_delay_vars:
self.global_delay_vars[name] = bm.LengthDelay(variable, num_delay, delay_data)
# save into local delay vars when first seen "var",
# for later update current value!
self.local_delay_vars[name] = self.global_delay_vars[name]
else:
if self.global_delay_vars[name].num_delay_step - 1 < num_delay:
self.global_delay_vars[name].init(variable, num_delay, delay_data)

def get_delay(self):
pass

def update_delay(self):
pass

+ 108
- 226
brainpy/dyn/synapses/abstract_models.py View File

@@ -1,12 +1,13 @@
# -*- coding: utf-8 -*-

from typing import Union, Dict
from typing import Union, Dict, Callable

import brainpy.math as bm
from brainpy.connect import TwoEndConnector, All2All, One2One
from brainpy.dyn.base import NeuGroup, TwoEndConn
from brainpy.dyn.utils import init_delay
from brainpy.integrators import odeint
from brainpy.initialize import Initializer
from brainpy.integrators import odeint, JointEq
from brainpy.types import Tensor, Parameter

__all__ = [
@@ -59,14 +60,6 @@ class DeltaSynapse(TwoEndConn):
>>> plt.legend()
>>> plt.show()

**Model Parameters**

============= ============== ======== ===========================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- -------------------------------------------
w 1 mV The synaptic strength.
============= ============== ======== ===========================================

Parameters
----------
pre: NeuGroup
@@ -75,8 +68,18 @@ class DeltaSynapse(TwoEndConn):
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`.
delay_step: int, ndarray, JaxArray
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
w: float, ndarray, JaxArray, Initializer
The synaptic strength. Default is 1.
post_key: str
The key of the post variable. It should be a string. The key should
be the attribute of the post-synaptic neuron group.
post_has_ref: bool
Whether the post-synaptic group has refractory period.
"""

def __init__(
@@ -84,27 +87,36 @@ class DeltaSynapse(TwoEndConn):
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
w: Parameter = 1.,
delay_step: Parameter = None,
conn_type: str = 'sparse',
w: Union[float, Tensor, Initializer, Callable] = 1.,
delay_step: Union[float, Tensor, Initializer, Callable] = None,
post_key='V',
post_has_ref=False,
name: str = None,
):
super(DeltaSynapse, self).__init__(pre=pre, post=post, conn=conn, name=name)
self.check_pre_attrs('spike')
self.check_post_attrs(post_key)

# parameters
self.post_key = post_key
self.check_post_attrs(post_key)
self.post_has_ref = post_has_ref
if post_has_ref: # checking
if post_has_ref:
self.check_post_attrs('refractory')
self.w = w

# connections
assert self.conn is not None
self.conn_type = conn_type
if conn_type not in ['sparse', 'dense']:
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
if not isinstance(self.conn, (All2All, One2One)):
self.pre2post = self.conn.require('pre2post')
if conn_type == 'sparse':
self.pre2post = self.conn.require('pre2post')
elif conn_type == 'dense':
self.conn_mat = self.conn.require('conn_mat')
raise ValueError(f'Unknown connection type: {conn_type}')

# variables
delay_type, delay_step, delay_value = init_delay(delay_step, self.pre.spike)
@@ -132,7 +144,13 @@ class DeltaSynapse(TwoEndConn):
elif isinstance(self.conn, One2One):
post_vs = pre_spike * self.w
else:
post_vs = bm.pre2post_event_sum(pre_spike, self.pre2post, self.post.num, self.w)
if self.conn_type == 'sparse':
post_vs = bm.pre2post_event_sum(pre_spike,
self.pre2post,
self.post.num,
self.w)
else:
post_vs = self.w * pre_spike @ self.conn_mat

# updates
target = getattr(self.post, self.post_key)
@@ -142,162 +160,6 @@ class DeltaSynapse(TwoEndConn):
target += post_vs


class Exponential(TwoEndConn):
r"""Current-based exponential decay synapse model.

**Model Descriptions**

The single exponential decay synapse model assumes the release of neurotransmitter,
its diffusion across the cleft, the receptor binding, and channel opening all happen
very quickly, so that the channels instantaneously jump from the closed to the open state.
Therefore, its expression is given by

.. math::

g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau}

where :math:`\tau_{delay}` is the time constant of the synaptic state decay,
:math:`t_0` is the time of the pre-synaptic spike,
:math:`g_{\mathrm{max}}` is the maximal conductance.

Accordingly, the differential form of the exponential synapse is given by

.. math::

\begin{aligned}
& g_{\mathrm{syn}}(t) = g_{max} g \\
& \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}).
\end{aligned}

For the current output onto the post-synaptic neuron, its expression is given by

.. math::

I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t)


**Model Examples**

- `(Brunel & Hakim, 1999) Fast Global Oscillation <https://brainpy-examples.readthedocs.io/en/latest/oscillation_synchronization/Brunel_Hakim_1999_fast_oscillation.html>`_
- `(Vreeswijk & Sompolinsky, 1996) E/I balanced network <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Vreeswijk_1996_EI_net.html>`_
- `(Brette, et, al., 2007) CUBA <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Brette_2007_CUBA.html>`_
- `(Tian, et al., 2020) E/I Net for fast response <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Tian_2020_EI_net_for_fast_response.html>`_

.. plot::
:include-source: True

>>> import brainpy as bp
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.LIF(1)
>>> neu2 = bp.dyn.LIF(1)
>>> syn1 = bp.dyn.ExpCUBA(neu1, neu2, bp.conn.All2All(), g_max=5.)
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>>
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.legend()
>>> plt.show()


**Model Parameters**

============= ============== ======== ===================================================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- -----------------------------------------------------------------------------------
delay 0 ms The decay length of the pre-synaptic spikes.
tau_decay 8 ms The time constant of decay.
g_max 1 µmho(µS) The maximum conductance.
============= ============== ======== ===================================================================================

**Model Variables**

================ ================== =========================================================
**Member name** **Initial values** **Explanation**
---------------- ------------------ ---------------------------------------------------------
g 0 Gating variable.
pre_spike False The history spiking states of the pre-synaptic neurons.
================ ================== =========================================================

**References**

.. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
"The Synapse." Principles of Computational Modelling in Neuroscience.
Cambridge: Cambridge UP, 2011. 172-95. Print.
"""
def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
g_max: Parameter = 1.,
tau: Parameter = 8.0,
method: str = 'exp_auto',
out_type='coba',
delay_step=None,
name: str = None
):
super(Exponential, self).__init__(pre=pre, post=post, conn=conn, name=name)
self.check_pre_attrs('spike')
self.check_post_attrs('input', 'V')

# parameters
self.tau = tau
self.g_max = g_max
self.out_type = out_type
if out_type not in ['coba', 'cuba']:
raise ValueError

# connection
assert self.conn is not None
if not isinstance(self.conn, (All2All, One2One)):
self.pre2post = self.conn.require('pre2post')

# variables
self.g = bm.Variable(bm.zeros(self.post.num))
delay_type, delay_step, delay_value = init_delay(delay_step, self.pre.spike)
self.delay_type = delay_type
self.delay_step = delay_step
self.pre_spike = delay_value

# function
self.integral = odeint(lambda g, t: -g / self.tau, method=method)

def update(self, _t, _dt):
if self.delay_type == 'homo':
delayed_pre_spike = self.pre_spike(self.delay_step)
self.pre_spike.update(self.pre.spike)
elif self.delay_type == 'heter':
delayed_pre_spike = self.pre_spike(self.delay_step, bm.arange(self.pre.num))
self.pre_spike.update(self.pre.spike)
else:
delayed_pre_spike = self.pre.spike

# post values
if isinstance(self.conn, All2All):
post_vs = bm.sum(delayed_pre_spike)
if not self.conn.include_self:
post_vs = post_vs - delayed_pre_spike
post_vs *= self.g_max
elif isinstance(self.conn, One2One):
post_vs = delayed_pre_spike * self.g_max
else:
post_vs = bm.pre2post_event_sum(delayed_pre_spike, self.pre2post, self.post.num, self.g_max)

# updates
self.g.value = self.integral(self.g.value, _t, dt=_dt) + post_vs
self.post.input += self.g


class ExpCUBA(TwoEndConn):
r"""Current-based exponential decay synapse model.

@@ -396,8 +258,9 @@ class ExpCUBA(TwoEndConn):
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
g_max: Parameter = 1.,
tau: Parameter = 8.0,
conn_type: str = 'sparse',
g_max: Union[float, Tensor, Initializer, Callable] = 1.,
tau: Union[float, Tensor, Initializer, Callable] = 8.0,
method: str = 'exp_auto',
delay_step=None,
name: str = None
@@ -411,9 +274,17 @@ class ExpCUBA(TwoEndConn):
self.g_max = g_max

# connection
assert self.conn is not None
self.conn_type = conn_type
if conn_type not in ['sparse', 'dense']:
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
if not isinstance(self.conn, (All2All, One2One)):
self.pre2post = self.conn.require('pre2post')
if conn_type == 'sparse':
self.pre2post = self.conn.require('pre2post')
elif conn_type == 'dense':
self.conn_mat = self.conn.require('conn_mat')
raise ValueError(f'Unknown connection type: {conn_type}')

# variables
self.g = bm.Variable(bm.zeros(self.post.num))
@@ -444,7 +315,13 @@ class ExpCUBA(TwoEndConn):
elif isinstance(self.conn, One2One):
post_vs = delayed_pre_spike * self.g_max
else:
post_vs = bm.pre2post_event_sum(delayed_pre_spike, self.pre2post, self.post.num, self.g_max)
if self.conn_type == 'sparse':
post_vs = bm.pre2post_event_sum(delayed_pre_spike,
self.pre2post,
self.post.num,
self.g_max)
else:
post_vs = delayed_pre_spike @ self.conn_mat

# updates
self.g.value = self.integral(self.g.value, _t, dt=_dt) + post_vs
@@ -529,15 +406,22 @@ class ExpCOBA(ExpCUBA):
self,
pre: NeuGroup,
post: NeuGroup,
# connection
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
g_max: Parameter = 1.,
tau: Parameter = 8.0,
E: Parameter = 0.,
delay_step: Parameter = None,
conn_type: str = 'sparse',
# connection strength
g_max: Union[float, Tensor, Initializer, Callable] = 1.,
# synapse parameter
tau: float = 8.0,
E: float = 0.,
# synapse delay
delay_step: Union[float, Tensor, Initializer, Callable] = None,
# others
method: str = 'exp_auto',
name: str = None
):
super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn,
conn_type=conn_type,
g_max=g_max, delay_step=delay_step,
tau=tau, method=method, name=name)

@@ -563,7 +447,13 @@ class ExpCOBA(ExpCUBA):
elif isinstance(self.conn, One2One):
post_vs = delayed_spike * self.g_max
else:
post_vs = bm.pre2post_event_sum(delayed_spike, self.pre2post, self.post.num, self.g_max)
if self.conn_type == 'sparse':
post_vs = bm.pre2post_event_sum(delayed_spike,
self.pre2post,
self.post.num,
self.g_max)
else:
post_vs = delayed_spike @ self.conn_mat

# updates
self.g.value = self.integral(self.g, _t, dt=_dt) + post_vs
@@ -690,27 +580,29 @@ class DualExpCUBA(TwoEndConn):
self.g_max = g_max

# connections
if not isinstance(self.conn, One2One):
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
if not isinstance(self.conn, (One2One, All2All)):
self.conn_mat = self.conn.require('conn_mat')

# variables
if isinstance(self.conn, One2One):
self.h = bm.Variable(bm.zeros(self.post.num))
self.g = bm.Variable(bm.zeros(self.post.num))
elif isinstance(self.conn, All2All):
self.h = bm.Variable(bm.zeros(self.pre.num))
self.g = bm.Variable(bm.zeros(self.pre.num))
else:
self.h = bm.Variable(bm.zeros(self.pre.num))
self.g = bm.Variable(bm.zeros(self.pre_ids.size))
self.h = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
self.g = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
delay_type, delay_step, delay_value = init_delay(delay_step, self.pre.spike)
self.delay_type = delay_type
self.delay_step = delay_step
self.pre_spike = delay_value

# integral
self.int_h = odeint(method=method, f=lambda h, t: -h / self.tau_rise)
self.int_g = odeint(method=method, f=lambda g, t, h: -g / self.tau_decay + h)
self.integral = odeint(method=method, f=self.derivative)

def dh(self, h, t):
return -h / self.tau_rise

def dg(self, g, t, h):
return -g / self.tau_decay + h

@property
def derivative(self):
return JointEq([self.dg, self.dh])

def update(self, _t, _dt):
# delays
@@ -724,21 +616,17 @@ class DualExpCUBA(TwoEndConn):
delayed_pre_spike = self.pre.spike

# post-synaptic values
self.g.value, self.h.value = self.integral(self.g, self.h, _t, _dt)
self.h += delayed_pre_spike
if isinstance(self.conn, All2All):
self.g.value = self.int_g(self.g, _t, self.h, _dt)
self.h.value = self.int_h(self.h, _t, _dt) + delayed_pre_spike
post_vs = self.g.sum()
if not self.conn.include_self:
post_vs = post_vs - self.g
post_vs = self.g_max * post_vs
elif isinstance(self.conn, One2One):
self.g.value = self.int_g(self.g, _t, self.h, _dt)
self.h.value = self.int_h(self.h, _t, _dt) + delayed_pre_spike
post_vs = self.g_max * self.g
else:
self.g.value = self.int_g(self.g, _t, bm.pre2syn(self.h, self.pre_ids), _dt)
self.h.value = self.int_h(self.h, _t, _dt) + delayed_pre_spike
post_vs = self.g_max * bm.syn2post(self.g, self.post_ids, self.post.num)
post_vs = self.g_max * self.g @ self.conn_mat

# output
self.post.input += self.output(post_vs)
@@ -1207,23 +1095,25 @@ class NMDA(TwoEndConn):

# connections
if not isinstance(self.conn, (All2All, One2One)):
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
self.conn_mat = self.conn.require('conn_mat')

# variables
if isinstance(self.conn, All2All):
self.g = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
self.x = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
elif isinstance(self.conn, One2One):
self.g = bm.Variable(bm.zeros(self.post.num, dtype=bm.float_))
self.x = bm.Variable(bm.zeros(self.post.num, dtype=bm.float_))
else:
self.g = bm.Variable(bm.zeros(self.pre_ids.size, dtype=bm.float_))
self.x = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
self.g = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
self.x = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
self.delay_type, self.delay_step, self.pre_spike = init_delay(delay_step, self.pre.spike)

# integral
self.int_g = odeint(method=method, f=lambda g, t, x: -g / self.tau_decay + self.a * x * (1 - g))
self.int_x = odeint(method=method, f=lambda x, t: -x / self.tau_rise)
self.integral = odeint(method=method, f=self.derivative)

def dg(self, g, t, x):
return -g / self.tau_decay + self.a * x * (1 - g)

def dx(self, x, t):
return -x / self.tau_rise

@property
def derivative(self):
return JointEq([self.dg, self.dx])

def update(self, _t, _dt):
# delayed pre-synaptic spikes
@@ -1237,25 +1127,17 @@ class NMDA(TwoEndConn):
delayed_pre_spike = self.pre.spike

# post-synaptic value
self.g.value, self.x.value = self.integral(self.g, self.x, _t, dt=_dt)
self.x += delayed_pre_spike
if isinstance(self.conn, All2All):
x = self.int_x(self.x, _t, dt=_dt)
self.g.value = self.int_g(self.g, _t, self.x, dt=_dt)
self.x.value = x + delayed_pre_spike
post_g = self.g.sum()
if not self.conn.include_self:
post_g = post_g - self.g
elif isinstance(self.conn, One2One):
x = self.int_x(self.x, _t, dt=_dt)
self.g.value = self.int_g(self.g, _t, self.x, dt=_dt)
self.x.value = x + delayed_pre_spike.sum()
post_g = self.g
else:
x = bm.pre2syn(self.x, self.pre_ids)
self.g.value = self.int_g(self.g, _t, x, dt=_dt)
self.x.value = self.int_x(self.x, _t, dt=_dt) + delayed_pre_spike
post_g = bm.syn2post(self.g, self.post_ids, self.post.num)
post_g = self.g @ self.conn_mat

# output
g_inf = 1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post.V)
self.post.input -= self.g_max * post_g * (self.post.V - self.E) / g_inf


+ 12
- 13
brainpy/dyn/synapses/biological_models.py View File

@@ -4,6 +4,7 @@ from typing import Union, Dict

import brainpy.math as bm
from brainpy.connect import TwoEndConnector, All2All, One2One
from brainpy.initialize import Initializer
from brainpy.dyn.base import NeuGroup, TwoEndConn
from brainpy.dyn.utils import init_delay
from brainpy.integrators import odeint
@@ -120,13 +121,13 @@ class AMPA(TwoEndConn):
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
g_max: Parameter = 0.42,
E: Parameter = 0.,
alpha: Parameter = 0.98,
beta: Parameter = 0.18,
T: Parameter = 0.5,
T_duration: Parameter = 0.5,
delay_step: Parameter = None,
g_max: Union[float, Tensor, Initializer] = 0.42,
E: float = 0.,
alpha: float = 0.98,
beta: float = 0.18,
T: float = 0.5,
T_duration: float = 0.5,
delay_step: Union[int, Tensor, Initializer] = None,
method: str = 'exp_auto',
name: str = None
):
@@ -145,7 +146,7 @@ class AMPA(TwoEndConn):
# connection
assert self.conn is not None
if not isinstance(self.conn, (All2All, One2One)):
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
self.conn_mat = self.conn.require('conn_mat')

# variables
if isinstance(self.conn, All2All):
@@ -153,7 +154,7 @@ class AMPA(TwoEndConn):
elif isinstance(self.conn, One2One):
self.g = bm.Variable(bm.zeros(self.post.num))
else:
self.g = bm.Variable(bm.zeros(len(self.pre_ids)))
self.g = bm.Variable(bm.zeros(self.pre.num))
self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num) * -1e7)
self.delay_type, self.delay_step, self.pre_spike = init_delay(delay_step, self.pre.spike)

@@ -190,10 +191,9 @@ class AMPA(TwoEndConn):
if not self.conn.include_self:
g_post = g_post - self.g
else:
syn_sp_times = bm.pre2syn(self.spike_arrival_time, self.pre_ids)
TT = ((_t - syn_sp_times) < self.T_duration) * self.T
TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T
self.g.value = self.integral(self.g, _t, TT, dt=_dt)
g_post = bm.syn2post(self.g, self.post_ids, self.post.num)
g_post = self.g @ self.conn_mat

# output
self.post.input -= self.g_max * g_post * (self.post.V - self.E)
@@ -275,4 +275,3 @@ class GABAa(AMPA):
T_duration=T_duration,
method=method,
name=name)


+ 2
- 7
brainpy/dyn/synapses/delay_coupling.py View File

@@ -40,9 +40,7 @@ class DelayCoupling(TwoEndConn):

"""

"""Global delay variables. Useful when the same target
variable is used in multiple mappings."""
global_delay_vars: Dict[str, bm.LengthDelay] = dict()


def __init__(
self,
@@ -57,9 +55,6 @@ class DelayCoupling(TwoEndConn):
):
super(DelayCoupling, self).__init__(pre, post, name=name)

# local delay variables
self.local_delay_vars: Dict[str, bm.LengthDelay] = dict()

# domain
if domain not in ['global', 'local']:
raise ValueError('"domain" must be a string in ["global", "local"]. '
@@ -133,7 +128,7 @@ class DelayCoupling(TwoEndConn):
# for later update current value!
self.local_delay_vars[var] = self.global_delay_vars[var]
else:
if self.global_delay_vars[var].delay_len < num_delay_step:
if self.global_delay_vars[var].num_delay_step - 1 < num_delay_step:
variable = source_vars[var]
shape = (num_delay_step,) + variable.shape
delay_data = delay_initializer(shape, dtype=variable.dtype)


+ 2
- 2
brainpy/math/delayvars.py View File

@@ -273,8 +273,8 @@ class LengthDelay(AbstractDelay):
self.idx = Variable(jnp.asarray([0], dtype=jnp.int32))

# delay data
self.data = Variable(jnp.zeros((self.num_delay_step,) + self.shape,
dtype=inits.dtype))
self.data = Variable(jnp.zeros((self.num_delay_step,) + self.shape, dtype=inits.dtype))
self.data[-1] = inits
if delay_data is None:
pass
elif isinstance(delay_data, (ndarray, jnp.ndarray, float, int)):


+ 2
- 2
requirements-win.txt View File

@@ -1,6 +1,6 @@
numpy>=1.15
numpy==1.21
tqdm
numba
numba<=0.55
matplotlib>=3.4
sympy>=1.6
scipy>=1.1.0


Loading…
Cancel
Save