#31 updates

Merged
BrainPy merged 282 commits from updates into master 1 year ago
  1. +0 -10  .github/ISSUE_TEMPLATE/feature_request.md
  2. +0 -4  .github/PULL_REQUEST_TEMPLATE.md
  3. +2 -1  .github/workflows/Linux_CI.yml
  4. +3 -1  .github/workflows/MacOS_CI.yml
  5. +17 -0  .github/workflows/Publish.yml
  6. +2 -2  .github/workflows/Sync_branches.yml
  7. +4 -3  .github/workflows/Windows_CI.yml
  8. +0 -22  .github/workflows/contributors.yml
  9. +0 -19  .github/workflows/generate_changelog.yml
  10. +6 -1  .gitignore
  11. +39 -15  brainpy/__init__.py
  12. +1 -0  brainpy/algorithms/__init__.py
  13. +559 -0  brainpy/algorithms/offline.py
  14. +60 -34  brainpy/algorithms/online.py
  15. +111 -0  brainpy/algorithms/utils.py
  16. +3 -0  brainpy/analysis/__init__.py
  17. +12 -0  brainpy/analysis/base.py
  18. +9 -0  brainpy/analysis/constants.py
  19. +1 -1  brainpy/analysis/highdim/__init__.py
  20. +624 -142  brainpy/analysis/highdim/slow_points.py
  21. +165 -0  brainpy/analysis/highdim/tests/test_slow_points.py
  22. +34 -24  brainpy/analysis/lowdim/lowdim_analyzer.py
  23. +2 -0  brainpy/analysis/lowdim/tests/test_phase_plane.py
  24. +67 -10  brainpy/analysis/utils/measurement.py
  25. +51 -41  brainpy/analysis/utils/model.py
  26. +55 -6  brainpy/analysis/utils/others.py
  27. +79 -30  brainpy/base/base.py
  28. +56 -51  brainpy/base/collector.py
  29. +341 -104  brainpy/base/io.py
  30. +3 -2  brainpy/base/naming.py
  31. +1 -1  brainpy/base/tests/test_base.py
  32. +0 -14  brainpy/base/tests/test_circular_reference.py
  33. +49 -70  brainpy/base/tests/test_collector.py
  34. +170 -0  brainpy/base/tests/test_io.py
  35. +0 -27  brainpy/compat/__init__.py
  36. +0 -92  brainpy/compat/brainobjects.py
  37. +0 -60  brainpy/compat/integrators.py
  38. +0 -61  brainpy/compat/layers.py
  39. +0 -98  brainpy/compat/models.py
  40. +0 -21  brainpy/compat/monitor.py
  41. +0 -65  brainpy/compat/runners.py
  42. +354 -127  brainpy/connect/random_conn.py
  43. +4 -2  brainpy/connect/regular_conn.py
  44. +5 -4  brainpy/connect/tests/test_random_conn.py
  45. +2 -1  brainpy/datasets/__init__.py
  46. +239 -0  brainpy/datasets/_internally_replaced_utils.py
  47. +264 -0  brainpy/datasets/base.py
  48. +5 -0  brainpy/datasets/chaos/__init__.py
  49. +3 -4  brainpy/datasets/chaos/chaotic_systems.py
  50. +4 -0  brainpy/datasets/vision/__init__.py
  51. +90 -0  brainpy/datasets/vision/base.py
  52. +561 -0  brainpy/datasets/vision/mnist.py
  53. +479 -0  brainpy/datasets/vision/utils.py
  54. +8 -7  brainpy/dyn/__init__.py
  55. +999 -388  brainpy/dyn/base.py
  56. +1093 -0  brainpy/dyn/channels/Ca.py
  57. +0 -862  brainpy/dyn/channels/Ca_channels.py
  58. +249 -0  brainpy/dyn/channels/IH.py
  59. +0 -95  brainpy/dyn/channels/Ih_channels.py
  60. +1021 -0  brainpy/dyn/channels/K.py
  61. +127 -0  brainpy/dyn/channels/KCa.py
  62. +0 -148  brainpy/dyn/channels/K_channels.py
  63. +371 -0  brainpy/dyn/channels/Na.py
  64. +0 -165  brainpy/dyn/channels/Na_channels.py
  65. +13 -14  brainpy/dyn/channels/__init__.py
  66. +121 -7  brainpy/dyn/channels/base.py
  67. +90 -0  brainpy/dyn/channels/leaky.py
  68. +0 -75  brainpy/dyn/channels/leaky_channels.py
  69. +10 -0  brainpy/dyn/layers/__init__.py
  70. +220 -0  brainpy/dyn/layers/conv.py
  71. +20 -12  brainpy/dyn/layers/dropout.py
  72. +190 -0  brainpy/dyn/layers/linear.py
  73. +72 -53  brainpy/dyn/layers/normalization.py
  74. +202 -0  brainpy/dyn/layers/nvar.py
  75. +19 -17  brainpy/dyn/layers/pooling.py
  76. +217 -0  brainpy/dyn/layers/reservoir.py
  77. +425 -0  brainpy/dyn/layers/rnncells.py
  78. +95 -0  brainpy/dyn/layers/tests/test_conv.py
  79. +201 -0  brainpy/dyn/layers/tests/test_normalization.py
  80. +70 -0  brainpy/dyn/layers/tests/test_pooling.py
  81. +1 -0  brainpy/dyn/networks/__init__.py
  82. +25 -0  brainpy/dyn/networks/cann.py
  83. +2 -1  brainpy/dyn/neurons/__init__.py
  84. +458 -185  brainpy/dyn/neurons/biological_models.py
  85. +16 -0  brainpy/dyn/neurons/compat.py
  86. +76 -65  brainpy/dyn/neurons/fractional_models.py
  87. +207 -0  brainpy/dyn/neurons/input_groups.py
  88. +0 -159  brainpy/dyn/neurons/input_models.py
  89. +22 -19  brainpy/dyn/neurons/noise_groups.py
  90. +897 -362  brainpy/dyn/neurons/reduced_models.py
  91. +17 -0  brainpy/dyn/neurons/tests/test_reduced_models.py
  92. +0 -3  brainpy/dyn/rates/__init__.py
  93. +412 -312  brainpy/dyn/rates/populations.py
  94. +548 -228  brainpy/dyn/runners.py
  95. +4 -0  brainpy/dyn/synapses/__init__.py
  96. +404 -810  brainpy/dyn/synapses/abstract_models.py
  97. +357 -103  brainpy/dyn/synapses/biological_models.py
  98. +258 -0  brainpy/dyn/synapses/compat.py
  99. +106 -88  brainpy/dyn/synapses/delay_couplings.py
  100. +62 -0  brainpy/dyn/synapses/gap_junction.py

+0 -10  .github/ISSUE_TEMPLATE/feature_request.md

@@ -1,10 +0,0 @@
---
name: 'Feature Request'
about: 'Suggest a new idea or improvement for Brainpy'
labels: 'enhancement'
---

Please:

- [ ] Check for duplicate requests.
- [ ] Describe your goal, and if possible provide a code snippet with a motivating example.

+0 -4  .github/PULL_REQUEST_TEMPLATE.md

@@ -14,10 +14,6 @@
<!--- For example, markdown files should pass markdownlint locally according to the rules -->
<!--- See how your change affects other areas of the code, etc. -->

## Screenshots(optional)
<!--- If Screenshots is not necessary or not available in this pull request, you can delete this section -->
<!--- Changes including html and css are required to have screenshots -->

## Types of changes
<!--- What types of changes does your code introduce? -->
<!--- Only left the line that best describes this pull request -->


+2 -1  .github/workflows/Linux_CI.yml

@@ -28,6 +28,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
python setup.py install
- name: Lint with flake8
@@ -35,7 +36,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/

+3 -1  .github/workflows/MacOS_CI.yml

@@ -28,6 +28,8 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install jax==0.3.14
python -m pip install jaxlib==0.3.14
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
python setup.py install
- name: Lint with flake8
@@ -35,7 +37,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/

+17 -0  .github/workflows/Publish.yml

@@ -0,0 +1,17 @@
name: Publish to PyPI.org
on:
release:
types: [published]
jobs:
pypi:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
- run: python setup.py bdist_wheel
- name: Publish package
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}

+2 -2  .github/workflows/Sync_branches.yml

@@ -9,10 +9,10 @@ jobs:
steps:
- uses: actions/checkout@master

- name: Merge master -> brainpy-2.x
- name: Merge master -> brainpy-2.2.x
uses: devmasx/merge-branch@master
with:
type: now
from_branch: master
target_branch: brainpy-2.x
target_branch: brainpy-2.2.x
github_token: ${{ github.token }}

+4 -3  .github/workflows/Windows_CI.yml

@@ -28,8 +28,9 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install numpy==1.21.0
python -m pip install "jax[cpu]==0.3.5" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install numpy>=1.21.0
python -m pip install "jaxlib==0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
python -m pip install -r requirements-win.txt
python -m pip install tqdm brainpylib
python setup.py install
@@ -38,7 +39,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/

+0 -22  .github/workflows/contributors.yml

@@ -1,22 +0,0 @@
name: Add contributors
on:
schedule:
- cron: '20 20 * * *'
push:
branches: [ master ]

jobs:
add-contributors:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: BobAnkh/add-contributors@V2.1.0
with:
CONTRIBUTOR: '# Contributors'
COLUMN_PER_ROW: '6'
ACCESS_TOKEN: ${{secrets.GITHUB_TOKEN}}
IMG_WIDTH: '100'
FONT_SIZE: '14'
PATH: '/README.md'
COMMIT_MESSAGE: 'docs(README): update contributors'
AVATAR_SHAPE: 'round'

+0 -19  .github/workflows/generate_changelog.yml

@@ -1,19 +0,0 @@
name: Generate changelog
on:
release:
types: [created, edited]

jobs:
generate-changelog:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- uses: BobAnkh/auto-generate-changelog@master
with:
REPO_NAME: 'PKU-NIP-Lab/BrainPy'
ACCESS_TOKEN: ${{secrets.GITHUB_TOKEN}}
PATH: 'CHANGELOG.rst'
COMMIT_MESSAGE: 'docs(CHANGELOG): update release notes'
TYPE: 'feat:Feature,fix:Bug Fixes,docs:Documentation,refactor:Refactor,perf:Performance Improvements'

+6 -1  .gitignore

@@ -1,15 +1,20 @@
publishment.md
#experimental/
.vscode
io_test_tmp*

brainpy/base/tests/io_test_tmp*

development

brainpy/dyn/tests/data
examples/simulation/data
examples/simulation/results
examples/analysis/data
extensions/.idea
extensions/wheelhouse
extensions/dist
extensions/win_dll
extensions/fixed_wheels
extensions/build
extensions/cmake-build-debug
@@ -53,7 +58,6 @@ develop/benchmark/CUBA/annarchy*
develop/benchmark/CUBA/brian2*



*~
\#*\#
*.pyc
@@ -210,3 +214,4 @@ dmypy.json
cython_debug/

/docs/apis/simulation/generated/
!/brainpy/dyn/tests/data/

+39 -15  brainpy/__init__.py

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.1.9"
__version__ = "2.2.1"


try:
@@ -15,7 +15,7 @@ except ModuleNotFoundError:


# fundamental modules
from . import errors, tools, check
from . import errors, tools, check, modes


# "base" module
@@ -29,46 +29,70 @@ from . import math


# toolboxes
from . import connect, initialize, optimizers, measure, losses, datasets, inputs
from . import (connect, # synaptic connection
initialize, # weight initialization
optimizers, # gradient descent optimizers
losses, # loss functions
measure, # methods for data analysis
datasets, # methods for generating data
inputs, # methods for generating input currents
algorithms, # online or offline training algorithms
)


# numerical integrators
from . import integrators
from .integrators import ode
from .integrators import sde
from .integrators import dde
from .integrators import fde
from .integrators.ode import odeint
from .integrators.sde import sdeint
from .integrators.dde import ddeint
from .integrators.fde import fdeint
from .integrators.joint_eq import JointEq


# dynamics simulation
from . import dyn
from .dyn import (channels, # channel models
layers, # ANN layers
networks, # network models
neurons, # neuron groups
rates, # rate models
synapses, # synaptic dynamics
synouts, # synaptic output
synplast, # synaptic plasticity
)
from brainpy.dyn.base import (DynamicalSystem,
Container,
Sequential,
Network,
NeuGroup,
SynConn,
SynOut,
SynSTP,
SynLTP,
TwoEndConn,
CondNeuGroup,
Channel,)
from .dyn.runners import *


# dynamics training
from . import train


# neural networks modeling
from . import nn
# automatic dynamics analysis
from . import analysis


# running
from . import running


# automatic dynamics analysis
from . import analysis


# "visualization" module, will be removed soon
from .visualization import visualize


# compatible interface
from .compat import * # compat


# convenient access
conn = connect
init = initialize
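
A quick sanity check of the reorganized top-level namespace, based only on the re-exports visible in this diff (the identity assertions are assumptions about how the aliases resolve, not documented guarantees):

import brainpy as bp

print(bp.__version__)                        # expected: "2.2.1" after this change

# Convenient aliases added at the bottom of brainpy/__init__.py.
assert bp.conn is bp.connect
assert bp.init is bp.initialize

# Core dynamics classes are re-exported from brainpy.dyn.base,
# and the ODE integrator entry point is lifted to the top level.
assert bp.DynamicalSystem is bp.dyn.base.DynamicalSystem
assert bp.odeint is bp.integrators.ode.odeint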


brainpy/nn/algorithms/__init__.py → brainpy/algorithms/__init__.py

@@ -2,3 +2,4 @@

from .offline import *
from .online import *
from . import utils

+559 -0  brainpy/algorithms/offline.py

@@ -0,0 +1,559 @@
# -*- coding: utf-8 -*-

import warnings

import numpy as np
from jax.lax import while_loop

import brainpy.math as bm
from brainpy.base import Base
from brainpy.types import Array
from .utils import (Sigmoid,
Regularization, L1Regularization, L1L2Regularization, L2Regularization,
polynomial_features, normalize)

__all__ = [
# base class for offline training algorithm
'OfflineAlgorithm',

# training methods
'LinearRegression',
'RidgeRegression',
'LassoRegression',
'LogisticRegression',
'PolynomialRegression',
'PolynomialRidgeRegression',
'ElasticNetRegression',

# general supports
'get_supported_offline_methods',
'register_offline_method',
]

name2func = dict()


class OfflineAlgorithm(Base):
"""Base class for offline training algorithm."""

def __init__(self, name=None):
super(OfflineAlgorithm, self).__init__(name=name)

def __call__(self, identifier, target, input, output):
"""The training procedure.

Parameters
----------
identifier: str
The variable name.
target: JaxArray, ndarray
The 2d target data with the shape of `(num_batch, num_output)`.
input: JaxArray, ndarray
The 2d input data with the shape of `(num_batch, num_input)`.
output: JaxArray, ndarray
The 2d output data with the shape of `(num_batch, num_output)`.

Returns
-------
weight: JaxArray
The weights after fit.
"""
return self.call(identifier, target, input, output)

def call(self, identifier, targets, inputs, outputs) -> Array:
"""The training procedure.

Parameters
----------
identifier: str
The identifier.

inputs: JaxArray, jax.numpy.ndarray, numpy.ndarray
The 3d input data with the shape of `(num_batch, num_time, num_input)`,
or, the 2d input data with the shape of `(num_time, num_input)`.

targets: JaxArray, jax.numpy.ndarray, numpy.ndarray
The 3d target data with the shape of `(num_batch, num_time, num_output)`,
or the 2d target data with the shape of `(num_time, num_output)`.

outputs: JaxArray, jax.numpy.ndarray, numpy.ndarray
The 3d output data with the shape of `(num_batch, num_time, num_output)`,
or the 2d output data with the shape of `(num_time, num_output)`.

Returns
-------
weight: JaxArray
The weights after fit.
"""
raise NotImplementedError('Must implement the __call__ function by the subclass itself.')

def __repr__(self):
return self.__class__.__name__

def initialize(self, identifier, *args, **kwargs):
pass


def _check_data_2d_atls(x):
if x.ndim < 2:
raise ValueError(f'Data must be a 2d tensor. But we got {x.ndim}d: {x.shape}.')
if x.ndim != 2:
return x.reshape((-1, x.shape[-1]))
else:
return x


class RegressionAlgorithm(OfflineAlgorithm):
""" Base regression model. Models the relationship between a scalar dependent variable y and the independent
variables X.

Parameters
----------
max_iter: int
The number of training iterations the algorithm will tune the weights for.
learning_rate: float
The step length that will be used when updating the weights.
"""

def __init__(
self,
max_iter: int = None,
learning_rate: float = None,
regularizer: Regularization = None,
name: str = None
):
super(RegressionAlgorithm, self).__init__(name=name)
self.max_iter = max_iter
self.learning_rate = learning_rate
self.regularizer = regularizer

def initialize(self, identifier, *args, **kwargs):
pass

def init_weights(self, n_features, n_out):
""" Initialize weights randomly [-1/N, 1/N] """
limit = 1 / np.sqrt(n_features)
return bm.random.uniform(-limit, limit, (n_features, n_out))

def gradient_descent_solve(self, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))

# initialize weights
w = self.init_weights(inputs.shape[1], targets.shape[1])

def cond_fun(a):
i, par_old, par_new = a
return bm.logical_and(bm.logical_not(bm.allclose(par_old, par_new)),
i < self.max_iter).value

def body_fun(a):
i, _, par_new = a
# Gradient of regularization loss w.r.t w
y_pred = inputs.dot(par_new)
grad_w = bm.dot(inputs.T, -(targets - y_pred)) + self.regularizer.grad(par_new)
# Update the weights
par_new2 = par_new - self.learning_rate * grad_w
return i + 1, par_new, par_new2

# Tune parameters for n iterations
r = while_loop(cond_fun, body_fun, (0, w - 1e-8, w))
return r[-1]

def predict(self, W, X):
return bm.dot(X, W)


class LinearRegression(RegressionAlgorithm):
"""Training algorithm of least-square regression.

Parameters
----------
name: str
The name of the algorithm.
"""

def __init__(
self,
name: str = None,

# parameters for using gradient descent
max_iter: int = 1000,
learning_rate: float = 0.001,
gradient_descent: bool = False,
):
super(LinearRegression, self).__init__(name=name,
max_iter=max_iter,
learning_rate=learning_rate,
regularizer=Regularization(0.))
self.gradient_descent = gradient_descent

def call(self, identifier, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))

# solving
if self.gradient_descent:
return self.gradient_descent_solve(targets, inputs)
else:
weights = bm.linalg.lstsq(inputs, targets)
return weights[0]


name2func['linear'] = LinearRegression
name2func['lstsq'] = LinearRegression


class RidgeRegression(RegressionAlgorithm):
"""Training algorithm of ridge regression.

Parameters
----------
alpha: float
The regularization coefficient.

.. versionadded:: 2.2.0

beta: float
The regularization coefficient.

.. deprecated:: 2.2.0
Please use `alpha` to set regularization factor.

name: str
The name of the algorithm.
"""

def __init__(
self,
alpha: float = 1e-7,
beta: float = None,
name: str = None,

# parameters for using gradient descent
max_iter: int = 1000,
learning_rate: float = 0.001,
gradient_descent: bool = False,
):
if beta is not None:
warnings.warn(f"Please use 'alpha' to set regularization factor. "
f"'beta' has been deprecated since version 2.2.0.",
UserWarning)
alpha = beta
super(RidgeRegression, self).__init__(name=name,
max_iter=max_iter,
learning_rate=learning_rate,
regularizer=L2Regularization(alpha=alpha))
self.gradient_descent = gradient_descent

def call(self, identifier, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))

# solving
if self.gradient_descent:
return self.gradient_descent_solve(targets, inputs)
else:
temp = inputs.T @ inputs
if self.regularizer.alpha > 0.:
temp += self.regularizer.alpha * bm.eye(inputs.shape[-1])
weights = bm.linalg.pinv(temp) @ (inputs.T @ targets)
return weights

def __repr__(self):
return f'{self.__class__.__name__}(beta={self.regularizer.alpha})'


name2func['ridge'] = RidgeRegression


class LassoRegression(RegressionAlgorithm):
"""Lasso regression method for offline training.

Parameters
----------
alpha: float
Constant that multiplies the L1 term. Defaults to 1.0.
`alpha = 0` is equivalent to an ordinary least square.
max_iter: int
The maximum number of iterations.
degree: int
The degree of the polynomial that the independent variable X will be transformed to.
name: str
The name of the algorithm.
"""

def __init__(
self,
alpha: float = 1.0,
degree: int = 2,
add_bias: bool = False,
name: str = None,

# parameters for using gradient descent
max_iter: int = 1000,
learning_rate: float = 0.001,
gradient_descent: bool = True,
):
super(LassoRegression, self).__init__(name=name,
max_iter=max_iter,
learning_rate=learning_rate,
regularizer=L1Regularization(alpha=alpha))
self.gradient_descent = gradient_descent
self.add_bias = add_bias
assert gradient_descent
self.degree = degree

def call(self, identifier, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))

# solving
inputs = normalize(polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias))
return super(LassoRegression, self).gradient_descent_solve(targets, inputs)

def predict(self, W, X):
X = _check_data_2d_atls(bm.asarray(X))
X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias))
return super(LassoRegression, self).predict(W, X)


name2func['lasso'] = LassoRegression


class LogisticRegression(RegressionAlgorithm):
"""Logistic regression method for offline training.

Parameters
----------
learning_rate: float
The step length that will be taken when following the negative gradient during
training.
gradient_descent: boolean
True or false depending on if gradient descent should be used when training. If
false then we use batch optimization by least squares.
max_iter: int
The number of iteration to optimize the parameters.
name: str
The name of the algorithm.
"""

def __init__(
self,
learning_rate: float = .1,
gradient_descent: bool = True,
max_iter: int = 4000,
name: str = None,
):
super(LogisticRegression, self).__init__(name=name,
max_iter=max_iter,
learning_rate=learning_rate)
self.gradient_descent = gradient_descent
self.sigmoid = Sigmoid()

def call(self, identifier, targets, inputs, outputs=None) -> Array:
# prepare data
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
if targets.shape[-1] != 1:
raise ValueError(f'Target must be a scalar, but got multiple variables: {targets.shape}. ')
targets = targets.flatten()

# initialize parameters
param = self.init_weights(inputs.shape[1], targets.shape[1])

def cond_fun(a):
i, par_old, par_new = a
return bm.logical_and(bm.logical_not(bm.allclose(par_old, par_new)),
i < self.max_iter).value

def body_fun(a):
i, par_old, par_new = a
# Make a new prediction
y_pred = self.sigmoid(inputs.dot(par_new))
if self.gradient_descent:
# Move against the gradient of the loss function with
# respect to the parameters to minimize the loss
par_new2 = par_new - self.learning_rate * (y_pred - targets).dot(inputs)
else:
gradient = self.sigmoid.grad(inputs.dot(par_new))
diag_grad = bm.zeros((gradient.size, gradient.size))
diag = bm.arange(gradient.size)
diag_grad[diag, diag] = gradient
par_new2 = bm.linalg.pinv(inputs.T.dot(diag_grad).dot(inputs)).dot(inputs.T).dot(
diag_grad.dot(inputs).dot(par_new) + targets - y_pred)
return i + 1, par_new, par_new2

# Tune parameters for n iterations
r = while_loop(cond_fun, body_fun, (0, param+1., param))
return r[-1]

def predict(self, W, X):
return self.sigmoid(X @ W)


name2func['logistic'] = LogisticRegression


class PolynomialRegression(LinearRegression):
def __init__(
self,
degree: int = 2,
name: str = None,
add_bias: bool = False,

# parameters for using gradient descent
max_iter: int = 1000,
learning_rate: float = 0.001,
gradient_descent: bool = True,
):
super(PolynomialRegression, self).__init__(name=name,
max_iter=max_iter,
learning_rate=learning_rate,
gradient_descent=gradient_descent)
self.degree = degree
self.add_bias = add_bias

def call(self, identifier, targets, inputs, outputs=None):
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias)
return super(PolynomialRegression, self).call(identifier, targets, inputs)

def predict(self, W, X):
X = _check_data_2d_atls(bm.asarray(X))
X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias)
return super(PolynomialRegression, self).predict(W, X)


name2func['polynomial'] = PolynomialRegression


class PolynomialRidgeRegression(RidgeRegression):
def __init__(
self,
alpha: float = 1.0,
degree: int = 2,
name: str = None,
add_bias: bool = False,

# parameters for using gradient descent
max_iter: int = 1000,
learning_rate: float = 0.001,
gradient_descent: bool = True,
):
super(PolynomialRidgeRegression, self).__init__(alpha=alpha,
name=name,
max_iter=max_iter,
learning_rate=learning_rate,
gradient_descent=gradient_descent)
self.degree = degree
self.add_bias = add_bias

def call(self, identifier, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias)
return super(PolynomialRidgeRegression, self).call(identifier, targets, inputs)

def predict(self, W, X):
X = _check_data_2d_atls(bm.asarray(X))
X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias)
return super(PolynomialRidgeRegression, self).predict(W, X)


name2func['polynomial_ridge'] = PolynomialRidgeRegression


class ElasticNetRegression(RegressionAlgorithm):
"""

Parameters:
-----------
degree: int
The degree of the polynomial that the independent variable X will be transformed to.
reg_factor: float
The factor that will determine the amount of regularization and feature
shrinkage.
l1_ratio: float
Weighs the contribution of l1 and l2 regularization.
n_iterations: float
The number of training iterations the algorithm will tune the weights for.
learning_rate: float
The step length that will be used when updating the weights.
"""

def __init__(
self,
alpha: float = 1.0,
degree: int = 2,
l1_ratio: float = 0.5,
name: str = None,
add_bias: bool = False,

# parameters for using gradient descent
max_iter: int = 1000,
learning_rate: float = 0.001,
gradient_descent: bool = True,
):
super(ElasticNetRegression, self).__init__(
name=name,
max_iter=max_iter,
learning_rate=learning_rate,
regularizer=L1L2Regularization(alpha=alpha, l1_ratio=l1_ratio)
)
self.degree = degree
self.add_bias = add_bias
self.gradient_descent = gradient_descent
assert gradient_descent

def call(self, identifier, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
# solving
inputs = normalize(polynomial_features(inputs, degree=self.degree))
return super(ElasticNetRegression, self).gradient_descent_solve(targets, inputs)

def predict(self, W, X):
X = _check_data_2d_atls(bm.asarray(X))
X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias))
return super(ElasticNetRegression, self).predict(W, X)


name2func['elastic_net'] = ElasticNetRegression


def get_supported_offline_methods():
"""Get all supported offline training methods."""
return tuple(name2func.keys())


def register_offline_method(name: str, method: OfflineAlgorithm):
"""Register a new offline learning method.

Parameters
----------
name: str
The method name.
method: OfflineAlgorithm
The function method.
"""
if name in name2func:
raise ValueError(f'"{name}" has been registered in offline training methods.')
if not isinstance(method, OfflineAlgorithm):
raise ValueError(f'"method" must be an instance {OfflineAlgorithm.__name__}, but we got {type(method)}')
name2func[name] = method


def get(name: str) -> OfflineAlgorithm:
"""Get the training function according to the training method name."""
if name not in name2func:
raise ValueError(f'All offline methods are: {get_supported_offline_methods()}.\n'
f'But we got {name}.')
return name2func[name]
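
When `gradient_descent=False`, `RidgeRegression.call` above reduces to the standard closed form W = pinv(X^T X + alpha*I) @ X^T Y. A standalone NumPy sketch of the same computation, with hypothetical shapes and test data, makes the behaviour easy to verify outside of BrainPy:

import numpy as np

def ridge_fit(inputs, targets, alpha=1e-7):
    # inputs:  (num_time, num_input) design matrix X
    # targets: (num_time, num_output) target matrix Y
    temp = inputs.T @ inputs
    if alpha > 0.:
        temp += alpha * np.eye(inputs.shape[-1])
    return np.linalg.pinv(temp) @ (inputs.T @ targets)

# Illustrative check: recover known weights from noiseless data.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
W_true = rng.normal(size=(5, 2))
W_fit = ridge_fit(X, X @ W_true, alpha=1e-7)
assert np.allclose(W_fit, W_true, atol=1e-3)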

brainpy/nn/algorithms/online.py → brainpy/algorithms/online.py

@@ -2,13 +2,14 @@

import brainpy.math as bm
from brainpy.base import Base
from jax import vmap
import jax.numpy as jnp

__all__ = [
# base class
'OnlineAlgorithm',

# online learning algorithms
'ForceLearning',
'RLS',
'LMS',

@@ -26,12 +27,12 @@ class OnlineAlgorithm(Base):
def __init__(self, name=None):
super(OnlineAlgorithm, self).__init__(name=name)

def __call__(self, name, target, input, output):
def __call__(self, identifier, target, input, output):
"""The training procedure.

Parameters
----------
name: str
identifier: str
The variable name.
target: JaxArray, ndarray
The 2d target data with the shape of `(num_batch, num_output)`.
@@ -45,18 +46,17 @@ class OnlineAlgorithm(Base):
weight: JaxArray
The weights after fit.
"""
return self.call(name, target, input, output)
return self.call(identifier, target, input, output)

def initialize(self, name, *args, **kwargs):
raise NotImplementedError('Must implement the initialize() '
'function by the subclass itself.')
def initialize(self, identifier, *args, **kwargs):
pass

def call(self, name, target, input, output):
def call(self, identifier, target, input, output):
"""The training procedure.

Parameters
----------
name: str
identifier: str
The variable name.
target: JaxArray, ndarray
The 2d target data with the shape of `(num_batch, num_output)`.
@@ -77,7 +77,26 @@ class OnlineAlgorithm(Base):


class RLS(OnlineAlgorithm):
"""The recursive least squares (RLS)."""
"""The recursive least squares (RLS) algorithm.

RLS is an adaptive filter algorithm that recursively finds the
coefficients that minimize a weighted linear least squares cost
function relating to the input signals. This approach is in
contrast to other algorithms such as the least mean squares
(LMS) that aim to reduce the mean square error.

See Also
--------
LMS, ForceLearning

Parameters
----------
alpha: float
The learning rate.
name: str
The algorithm name.

"""

postfix = '.rls.P'

@@ -85,13 +104,13 @@ class RLS(OnlineAlgorithm):
super(RLS, self).__init__(name=name)
self.alpha = alpha

def initialize(self, name, feature_in, feature_out=None):
name = name + self.postfix
self.implicit_vars[name] = bm.Variable(bm.eye(feature_in) * self.alpha)
def initialize(self, identifier, feature_in, feature_out=None):
identifier = identifier + self.postfix
self.implicit_vars[identifier] = bm.Variable(bm.eye(feature_in) * self.alpha)

def call(self, name, target, input, output):
name = name + self.postfix
P = self.implicit_vars[name]
def call(self, identifier, target, input, output):
identifier = identifier + self.postfix
P = self.implicit_vars[identifier]
# update the inverse correlation matrix
k = bm.dot(P, input.T) # (num_input, num_batch)
hPh = bm.dot(input, k) # (num_batch, num_batch)
@@ -106,25 +125,33 @@ class RLS(OnlineAlgorithm):
name2func['rls'] = RLS


class ForceLearning(RLS):
postfix = '.force.P'


name2func['force'] = ForceLearning
class LMS(OnlineAlgorithm):
"""The least mean squares (LMS).

LMS algorithms are a class of adaptive filter used to mimic a desired filter
by finding the filter coefficients that relate to producing the least mean
square of the error signal (difference between the desired and the actual signal).
It is a stochastic gradient descent method in that the filter is only adapted
based on the error at the current time. It was invented in 1960 by
Stanford University professor Bernard Widrow and his first Ph.D. student, Ted Hoff.

class LMS(OnlineAlgorithm):
"""The least mean squares (LMS). """
Parameters
----------
alpha: float
The learning rate.
name: str
The target name.
"""

def __init__(self, alpha=0.1, name=None):
super(LMS, self).__init__(name=name)
self.alpha = alpha

def initialize(self, name, *args, **kwargs):
pass
def call(self, name, target, input, output):
return -self.alpha * bm.dot(output - target, output)
def call(self, identifier, target, input, output):
assert target.shape[0] == input.shape[0] == output.shape[0], 'Batch size should be consistent.'
error = bm.as_jax(output - target)
input = bm.as_jax(input)
return -self.alpha * bm.sum(vmap(jnp.outer)(input, error), axis=0)


name2func['lms'] = LMS
@@ -135,7 +162,7 @@ def get_supported_online_methods():
return tuple(name2func.keys())


def register_online_method(name, method):
def register_online_method(name: str, method: OnlineAlgorithm):
"""Register a new oneline learning method.

Parameters
@@ -146,14 +173,13 @@ def register_online_method(name, method):
The function method.
"""
if name in name2func:
raise ValueError(f'"{name}" has been registered in offline training methods.')
if not callable(method):
raise ValueError(f'"method" must be an instance of callable '
f'function, but we got {type(method)}')
raise ValueError(f'"{name}" has been registered in online training methods. Please choose another name.')
if not isinstance(method, OnlineAlgorithm):
raise ValueError(f'"method" must be an instance of {OnlineAlgorithm.__name__}, but we got {type(method)}')
name2func[name] = method


def get(name):
def get(name: str):
"""Get the training function according to the training method name."""
if name not in name2func:
raise ValueError(f'All online methods are: {get_supported_online_methods()}.\n'
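
For reference, the recursion carried out by `RLS.call` above can be written in plain NumPy as a standard recursive-least-squares step for a linear readout. This is an illustrative sketch with assumed shapes, not the BrainPy implementation (which keeps P in `implicit_vars` and operates on batched JaxArrays):

import numpy as np

def rls_step(P, W, x, target, output):
    # P:      (num_input, num_input) inverse correlation matrix
    # W:      (num_input, num_output) readout weights
    # x:      (1, num_input) current input sample
    # target: (1, num_output) desired output
    # output: (1, num_output) actual readout output
    k = P @ x.T                       # gain vector, (num_input, 1)
    hPh = float(x @ k)                # scalar x P x^T
    c = 1.0 / (1.0 + hPh)
    P = P - c * (k @ k.T)             # update inverse correlation matrix
    dW = -c * k @ (output - target)   # error-driven weight change
    return P, W + dW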

+111 -0  brainpy/algorithms/utils.py

@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-

import brainpy.math as bm

from itertools import combinations_with_replacement

__all__ = [
'Sigmoid',
'Regularization',
'L1Regularization',
'L2Regularization',
'L1L2Regularization',

'polynomial_features',
'normalize',
]


class Sigmoid(object):
def __call__(self, x):
return 1 / (1 + bm.exp(-x))

def grad(self, x):
exp = bm.exp(-x)
return exp / (1 + exp) ** 2


class Regularization(object):
def __init__(self, alpha):
self.alpha = alpha

def __call__(self, x):
return 0

def grad(self, x):
return 0


class L1Regularization(Regularization):
"""L1 Regularization."""

def __init__(self, alpha):
super(L1Regularization, self).__init__(alpha=alpha)

def __call__(self, w):
return self.alpha * bm.linalg.norm(w)

def grad(self, w):
return self.alpha * bm.sign(w)


class L2Regularization(Regularization):
"""L2 Regularization."""

def __init__(self, alpha):
super(L2Regularization, self).__init__(alpha=alpha)

def __call__(self, w):
return self.alpha * 0.5 * w.T.dot(w)

def grad(self, w):
return self.alpha * w


class L1L2Regularization(Regularization):
"""L1 and L2 Regularization."""

def __init__(self, alpha, l1_ratio=0.5):
super(L1L2Regularization, self).__init__(alpha=alpha)
self.l1_ratio = l1_ratio

def __call__(self, w):
l1_contr = self.l1_ratio * bm.linalg.norm(w)
l2_contr = (1 - self.l1_ratio) * 0.5 * w.T.dot(w)
return self.alpha * (l1_contr + l2_contr)

def grad(self, w):
l1_contr = self.l1_ratio * bm.sign(w)
l2_contr = (1 - self.l1_ratio) * w
return self.alpha * (l1_contr + l2_contr)


def index_combinations(n_features, degree):
combs = [combinations_with_replacement(range(n_features), i) for i in range(2, degree + 1)]
flat_combs = [item for sublist in combs for item in sublist]
return flat_combs


def polynomial_features(X, degree: int, add_bias: bool = True):
n_samples, n_features = X.shape
combinations = index_combinations(n_features, degree)
if len(combinations) == 0:
return bm.insert(X, 0, 1, axis=1) if add_bias else X
if add_bias:
n_features += 1
X_new = bm.zeros((n_samples, 1 + n_features + len(combinations)))
if add_bias:
X_new[:, 0] = 1
X_new[:, 1:n_features] = X
else:
X_new[:, :n_features] = X
for i, index_combs in enumerate(combinations):
X_new[:, n_features + i] = bm.prod(X[:, index_combs], axis=1)
return X_new


def normalize(X, axis=-1, order=2):
""" Normalize the dataset X """
l2 = bm.atleast_1d(bm.linalg.norm(X, order, axis))
l2 = bm.where(l2 == 0, 1, l2)
return X / bm.expand_dims(l2, axis)
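
To see what `polynomial_features` produces, it helps to expand a tiny two-feature input by hand. The sketch below mirrors the index-combination logic in plain NumPy (illustrative only; the BrainPy version above writes into a preallocated JaxArray):

import numpy as np
from itertools import combinations_with_replacement

def poly_features(X, degree=2, add_bias=True):
    n_samples, n_features = X.shape
    combs = [c for d in range(2, degree + 1)
             for c in combinations_with_replacement(range(n_features), d)]
    cols = [np.ones((n_samples, 1))] if add_bias else []
    cols.append(X)
    for idx in combs:
        # product of the selected feature columns, e.g. x1*x2 for idx=(0, 1)
        cols.append(np.prod(X[:, list(idx)], axis=1, keepdims=True))
    return np.concatenate(cols, axis=1)

X = np.array([[1., 2.],
              [3., 4.]])
# Columns: bias, x1, x2, x1*x1, x1*x2, x2*x2
print(poly_features(X, degree=2))
# [[ 1.  1.  2.  1.  2.  4.]
#  [ 1.  3.  4.  9. 12. 16.]]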

+3 -0  brainpy/analysis/__init__.py

@@ -14,11 +14,14 @@ This module provides analysis tools for differential equations.
Details in the following.
"""

from .base import *

from .highdim.slow_points import *

from .lowdim.lowdim_phase_plane import *
from .lowdim.lowdim_bifurcation import *

from .constants import *
from . import constants as C
from . import stability
from . import utils

+12 -0  brainpy/analysis/base.py

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-


__all__ = [
'DSAnalyzer'
]


class DSAnalyzer(object):
"""Base class of analyzers for dynamical systems in BrainPy"""
pass


+9 -0  brainpy/analysis/constants.py

@@ -1,6 +1,15 @@
# -*- coding: utf-8 -*-


__all__ = [
'CONTINUOUS',
'DISCRETE',
]


CONTINUOUS = 'continuous'
DISCRETE = 'discrete'

F_vmap_fx = 'F_vmap_fx'
F_vmap_fy = 'F_vmap_fy'
F_vmap_brentq_fx = 'F_vmap_brentq_fx'


+1 -1  brainpy/analysis/highdim/__init__.py

@@ -1,3 +1,3 @@
# -*- coding: utf-8 -*-

from .slow_points import *
from .slow_points import *

+624 -142  brainpy/analysis/highdim/slow_points.py

@@ -1,25 +1,38 @@
# -*- coding: utf-8 -*-

import math
import time
import warnings
from functools import partial
from typing import Callable, Union, Dict, Sequence, Tuple

from jax import vmap
import jax.numpy
import jax.numpy as jnp
import numpy as np
from jax import vmap
from jax.scipy.optimize import minimize
from jax.tree_util import tree_flatten, tree_map

import brainpy.math as bm
from brainpy import optimizers as optim
from brainpy.analysis import utils
from brainpy.errors import AnalyzerError
from brainpy import optimizers as optim, losses
from brainpy.analysis import utils, base, constants
from brainpy.base import TensorCollector
from brainpy.dyn.base import DynamicalSystem
from brainpy.dyn.runners import build_inputs, check_and_format_inputs
from brainpy.errors import AnalyzerError, UnsupportedError
from brainpy.tools.others.dicts import DotDict
from brainpy.types import Array

__all__ = [
'SlowPointFinder',
]

F_OPT_SOLVER = 'function_for_opt_solver'
F_GRADIENT_DESCENT = 'function_for_gradient_descent'

SUPPORTED_OPT_SOLVERS = {
'BFGS': lambda f, x0: minimize(f, x0, method='BFGS')
}

class SlowPointFinder(object):

class SlowPointFinder(base.DSAnalyzer):
"""Find fixed/slow points by numerical optimization.

This class can help you:
@@ -29,197 +42,415 @@ class SlowPointFinder(object):
- exclude any non-unique fixed points according to a tolerance
- exclude any far-away "outlier" fixed points

This model implementation is inspired by https://github.com/google-research/computation-thru-dynamics.

Parameters
----------
f_cell : callable, function
The function to compute the recurrent units.
f_cell : callable, function, DynamicalSystem
The target of computing the recurrent units.

f_type : str
The system's type: continuous system or discrete system.

- 'continuous': continuous derivative function, denotes this is a continuous system, or
- 'discrete': discrete update function, denotes this is a discrete system.

verbose : bool
Whether output the optimization progress.

f_loss: callable
The loss function.
- If ``f_type`` is `"discrete"`, the loss function must receive three arguments, i.e.,
``loss(outputs, targets, axis)``.
- If ``f_type`` is `"continuous"`, the loss function must receive two arguments, i.e.,
``loss(outputs, axis)``.

.. versionadded:: 2.2.0

t: float
Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`.
The time to evaluate the fixed points. Default is 0.

.. versionadded:: 2.2.0

dt: float
Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`.
The numerical integration step.
The default is given by `brainpy.math.get_dt()`.

.. versionadded:: 2.2.0

inputs: sequence
Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`.
Same as ``inputs`` in :py:class:`~.DSRunner`.

.. versionadded:: 2.2.0

excluded_vars: sequence, dict
Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`.
The excluded variables (can be a sequence of `Variable` instances).
These variables will not be included for optimization of fixed points.

.. versionadded:: 2.2.0

target_vars: dict
Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`.
The target variables (can be a dict of `Variable` instances).
These variables will be included for optimization of fixed points.
The candidate points later provided should have same keys as in ``target_vars``.

.. versionadded:: 2.2.0

f_loss_batch : callable, function
Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`.
The function to compute the loss.
verbose : bool
Whether to print the optimization progress.
"""

def __init__(self, f_cell, f_type='continuous', f_loss_batch=None, verbose=True):
self.verbose = verbose
if f_type not in ['discrete', 'continuous']:
raise AnalyzerError(f'Only support "continuous" (continuous derivative function) or '
f'"discrete" (discrete update function), not {f_type}.')
.. deprecated:: 2.2.0
Has been removed. Please use ``f_loss`` to set different loss function.

# functions
self.f_cell = f_cell
if f_loss_batch is None:
if f_type == 'discrete':
self.f_loss = bm.jit(lambda h: bm.mean((h - f_cell(h)) ** 2))
self.f_loss_batch = bm.jit(lambda h: bm.mean((h - vmap(f_cell)(h)) ** 2, axis=1))
if f_type == 'continuous':
self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
self.f_loss_batch = bm.jit(lambda h: bm.mean((vmap(f_cell)(h)) ** 2, axis=1))
"""

def __init__(
self,
f_cell: Union[Callable, DynamicalSystem],
f_type: str = None,
f_loss: Callable = None,
verbose: bool = True,
args: Tuple = (),

# parameters for `f_cell` is DynamicalSystem instance
inputs: Sequence = None,
fun_inputs: Callable = None,
t: float = None,
dt: float = None,
target_vars: Dict[str, bm.Variable] = None,
excluded_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None,

# deprecated
f_loss_batch: Callable = None,
):
super(SlowPointFinder, self).__init__()

# static arguments
if not isinstance(args, tuple):
raise ValueError(f'args must be an instance of tuple, but we got {type(args)}')
self.args = args

# update function
if target_vars is None:
self.target_vars = TensorCollector()
else:
if not isinstance(target_vars, dict):
raise TypeError(f'"target_vars" must be a dict but we got {type(target_vars)}')
self.target_vars = TensorCollector(target_vars)
excluded_vars = () if excluded_vars is None else excluded_vars
if isinstance(excluded_vars, dict):
excluded_vars = tuple(excluded_vars.values())
if not isinstance(excluded_vars, (tuple, list)):
raise TypeError(f'"excluded_vars" must be a sequence but we got {type(excluded_vars)}')
for v in excluded_vars:
if not isinstance(v, bm.Variable):
raise TypeError(f'"excluded_vars" must be a sequence of Variable, '
f'but we got {type(v)}')
self.excluded_vars = {f'_exclude_v{i}': v for i, v in enumerate(excluded_vars)}
if len(self.target_vars) > 0 and len(self.excluded_vars) > 0:
raise ValueError('"target_vars" and "excluded_vars" cannot be provided simultaneously.')
self.target = f_cell

if isinstance(f_cell, DynamicalSystem):
# included variables
all_vars = f_cell.vars(method='relative', level=-1, include_self=True).unique()

# exclude variables
if len(self.target_vars) > 0:
_all_ids = [id(v) for v in self.target_vars.values()]
for k, v in all_vars.items():
if id(v) not in _all_ids:
self.excluded_vars[k] = v
else:
self.target_vars = all_vars
if len(excluded_vars):
excluded_vars = [id(v) for v in excluded_vars]
for key, val in tuple(self.target_vars.items()):
if id(val) in excluded_vars:
self.target_vars.pop(key)

# input function
if inputs is not None:
inputs = check_and_format_inputs(host=self.target, inputs=inputs)
_input_step, _has_iter = build_inputs(inputs, fun_inputs)
if _has_iter:
raise UnsupportedError(f'Do not support iterable inputs when using fixed point finder.')
else:
_input_step = None

# check included variables
for var in self.target_vars.values():
if var.batch_axis is not None:
if var.shape[var.batch_axis] != 1:
raise ValueError(f'Batched variables should has only one batch. '
f'But we got {var.shape[var.batch_axis]}. Maybe '
f'you need to call ".reset_state(batch_size=1)" '
f'for your system.')

# update function
self.f_cell = self._generate_ds_cell_function(self.target, t, dt, _input_step)

# check function type
if f_type is not None:
if f_type != constants.DISCRETE:
raise ValueError(f'"f_type" must be "{constants.DISCRETE}" when "f_cell" '
f'is instance of {DynamicalSystem.__name__}')
f_type = constants.DISCRETE

# original data
self.target_data = {k: v.value for k, v in self.target_vars.items()}
self.excluded_data = {k: v.value for k, v in self.excluded_vars.items()}

elif callable(f_cell):
if len(self.args) > 0:
self.f_cell = lambda x: f_cell(x, *self.args)
else:
self.f_cell = f_cell
if inputs is not None:
raise UnsupportedError('Do not support "inputs" when "f_cell" is not instance of '
f'{DynamicalSystem.__name__}')
if t is not None:
raise UnsupportedError('Do not support "t" when "f_cell" is not instance of '
f'{DynamicalSystem.__name__}')
if dt is not None:
raise UnsupportedError('Do not support "dt" when "f_cell" is not instance of '
f'{DynamicalSystem.__name__}')
if target_vars is not None:
raise UnsupportedError('Do not support "target_vars" when "f_cell" is not instance of '
f'{DynamicalSystem.__name__}')
if len(excluded_vars) > 0:
raise UnsupportedError('Do not support "excluded_vars" when "f_cell" is not instance of '
f'{DynamicalSystem.__name__}')
else:
self.f_loss_batch = f_loss_batch
self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
self.f_jacob_batch = bm.jit(vmap(bm.jacobian(f_cell)))
raise ValueError(f'Unknown type of "f_type": {type(f_cell)}')
if f_type not in [constants.DISCRETE, constants.CONTINUOUS]:
raise AnalyzerError(f'Only support "{constants.CONTINUOUS}" (continuous derivative function) or '
f'"{constants.DISCRETE}" (discrete update function), not {f_type}.')
self.verbose = verbose
self.f_type = f_type

# loss functon
if f_loss_batch is not None:
raise UnsupportedError('"f_loss_batch" is no longer supported, please '
'use "f_loss" instead.')
if f_loss is None:
f_loss = losses.mean_squared_error if f_type == constants.DISCRETE else losses.mean_square
self.f_loss = f_loss

# essential variables
self._losses = None
self._fixed_points = None
self._selected_ids = None
self.opt_losses = None
self._opt_losses = None

# functions
self._opt_functions = dict()

@property
def opt_losses(self) -> np.ndarray:
"""The optimization losses."""
return np.asarray(self._opt_losses)

@opt_losses.setter
def opt_losses(self, val):
raise UnsupportedError('Do not support set "opt_losses" by users.')

@property
def fixed_points(self):
def fixed_points(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""The final fixed points found."""
return self._fixed_points
return tree_map(lambda a: np.asarray(a), self._fixed_points)

@fixed_points.setter
def fixed_points(self, val):
raise UnsupportedError('Do not support set "fixed_points" by users.')

@property
def losses(self):
def num_fps(self) -> int:
if isinstance(self._fixed_points, dict):
return tuple(self._fixed_points.values())[0].shape[0]
else:
return self._fixed_points.shape[0]

@property
def losses(self) -> np.ndarray:
"""Losses of fixed points."""
return self._losses
return np.asarray(self._losses)

@losses.setter
def losses(self, val):
raise UnsupportedError('Do not support set "losses" by users.')

@property
def selected_ids(self):
def selected_ids(self) -> np.ndarray:
"""The selected ids of candidate points."""
return self._selected_ids

def find_fps_with_gd_method(self,
candidates,
tolerance=1e-5,
num_batch=100,
num_opt=10000,
optimizer=None,
opt_setting=None):
return np.asarray(self._selected_ids)

@selected_ids.setter
def selected_ids(self, val):
raise UnsupportedError('Do not support set "selected_ids" by users.')

def find_fps_with_gd_method(
self,
candidates: Union[Array, Dict[str, Array]],
tolerance: Union[float, Dict[str, float]] = 1e-5,
num_batch: int = 100,
num_opt: int = 10000,
optimizer: optim.Optimizer = None,
):
"""Optimize fixed points with gradient descent methods.

Parameters
----------
candidates : jax.ndarray, JaxArray
candidates : Array, dict
The array with the shape of (batch size, state dim) of hidden states
of RNN to start training for fixed points.

tolerance: float
The loss threshold during optimization

num_opt : int
The maximum number of optimization.

num_batch : int
Print training information during optimization every so often.
opt_setting: optional, dict
The optimization settings.

.. deprecated:: 2.1.2
Use "optimizer" to set optimization method instead.

optimizer: optim.Optimizer
The optimizer instance.

.. versionadded:: 2.1.2
"""

# optimization settings
if opt_setting is None:
if optimizer is None:
optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
beta1=0.9, beta2=0.999, eps=1e-8)
else:
assert isinstance(optimizer, optim.Optimizer), (f'Must be an instance of '
f'{optim.Optimizer.__name__}, '
f'while we got {type(optimizer)}')
if optimizer is None:
optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
beta1=0.9, beta2=0.999, eps=1e-8)
else:
warnings.warn('Please use "optimizer" to set optimization method. '
'"opt_setting" is deprecated since version 2.1.2. ',
DeprecationWarning)

assert isinstance(opt_setting, dict)
assert 'method' in opt_setting
assert 'lr' in opt_setting
opt_method = opt_setting.pop('method')
if isinstance(opt_method, str):
assert opt_method in optim.__dict__
opt_method = getattr(optim, opt_method)
assert issubclass(opt_method, optim.Optimizer)
opt_lr = opt_setting.pop('lr')
assert isinstance(opt_lr, (int, float, optim.Scheduler))
opt_setting = opt_setting
optimizer = opt_method(lr=opt_lr, **opt_setting)

if self.verbose:
print(f"Optimizing with {optimizer} to find fixed points:")
if not isinstance(optimizer, optim.Optimizer):
raise ValueError(f'Must be an instance of {optim.Optimizer.__name__}, '
f'while we got {type(optimizer)}')

# set up optimization
fixed_points = bm.Variable(bm.asarray(candidates))
grad_f = bm.grad(lambda: self.f_loss_batch(fixed_points.value).mean(),
grad_vars={'a': fixed_points}, return_value=True)
optimizer.register_vars({'a': fixed_points})
dyn_vars = optimizer.vars() + {'_a': fixed_points}
num_candidate = self._check_candidates(candidates)
if not (isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)) or isinstance(candidates, dict)):
raise ValueError('Candidates must be instance of JaxArray or dict of JaxArray.')
fixed_points = tree_map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.JaxArray))
f_eval_loss = self._get_f_eval_loss()

def f_loss():
return f_eval_loss(tree_map(lambda a: bm.as_device_array(a),
fixed_points,
is_leaf=lambda x: isinstance(x, bm.JaxArray))).mean()

grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True)
optimizer.register_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points})
dyn_vars = optimizer.vars() + (fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points})
dyn_vars = dyn_vars.unique()

def train(idx):
gradients, loss = grad_f()
optimizer.update(gradients)
optimizer.update(gradients if isinstance(gradients, dict) else {'a': gradients})
return loss

@partial(bm.jit, dyn_vars=dyn_vars, static_argnames=('start_i', 'num_batch'))
def batch_train(start_i, num_batch):
def batch_train(start_i, n_batch):
f = bm.make_loop(train, dyn_vars=dyn_vars, has_return=True)
return f(bm.arange(start_i, start_i + num_batch))
return f(bm.arange(start_i, start_i + n_batch))

# Run the optimization
if self.verbose:
print(f"Optimizing with {optimizer} to find fixed points:")
opt_losses = []
do_stop = False
num_opt_loops = int(num_opt / num_batch)
for oidx in range(num_opt_loops):
if do_stop: break
if do_stop:
break
batch_idx_start = oidx * num_batch
start_time = time.time()
(_, losses) = batch_train(start_i=batch_idx_start, num_batch=num_batch)
(_, train_losses) = batch_train(start_i=batch_idx_start, n_batch=num_batch)
batch_time = time.time() - start_time
opt_losses.append(losses)
opt_losses.append(train_losses)

if self.verbose:
print(f" "
f"Batches {batch_idx_start + 1}-{batch_idx_start + num_batch} "
f"in {batch_time:0.2f} sec, Training loss {losses[-1]:0.10f}")
f"in {batch_time:0.2f} sec, Training loss {train_losses[-1]:0.10f}")

if losses[-1] < tolerance:
if train_losses[-1] < tolerance:
do_stop = True
if self.verbose:
print(f' '
f'Stop optimization as mean training loss {losses[-1]:0.10f} '
f'Stop optimization as mean training loss {train_losses[-1]:0.10f} '
f'is below tolerance {tolerance:0.10f}.')
self.opt_losses = bm.concatenate(opt_losses)
self._losses = np.asarray(self.f_loss_batch(fixed_points))
self._fixed_points = np.asarray(fixed_points)
self._selected_ids = np.arange(fixed_points.shape[0])

def find_fps_with_opt_solver(self, candidates, opt_method=None):
self._opt_losses = bm.concatenate(opt_losses)
self._losses = f_eval_loss(tree_map(lambda a: bm.as_device_array(a),
fixed_points,
is_leaf=lambda x: isinstance(x, bm.JaxArray)))
self._fixed_points = tree_map(lambda a: bm.as_device_array(a),
fixed_points,
is_leaf=lambda x: isinstance(x, bm.JaxArray))
self._selected_ids = jnp.arange(num_candidate)

if isinstance(self.target, DynamicalSystem):
for k, v in self.excluded_vars.items():
v.value = self.excluded_data[k]
for k, v in self.target_vars.items():
v.value = self.target_data[k]

def find_fps_with_opt_solver(
self,
candidates: Union[Array, Dict[str, Array]],
opt_solver: str = 'BFGS'
):
"""Optimize fixed points with nonlinear optimization solvers.

Parameters
----------
candidates
opt_method: function, callable
candidates: Array, dict
The candidate (initial) fixed points.
opt_solver: str
The solver of the optimization.
"""
# optimization function
num_candidate = self._check_candidates(candidates)
for var in self.target_vars.values():
if bm.ndim(var) != 1:
raise ValueError('Cannot use opt solver.')
if self._opt_functions.get(F_OPT_SOLVER, None) is None:
self._opt_functions[F_OPT_SOLVER] = self._get_f_for_opt_solver(candidates, SUPPORTED_OPT_SOLVERS[opt_solver])
f_opt = self._opt_functions[F_OPT_SOLVER]

assert bm.ndim(candidates) == 2 and isinstance(candidates, (bm.JaxArray, jax.numpy.ndarray))
if opt_method is None:
opt_method = lambda f, x0: minimize(f, x0, method='BFGS')
if self.verbose:
print(f"Optimizing to find fixed points:")
f_opt = bm.jit(vmap(lambda x0: opt_method(self.f_loss, x0)))
res = f_opt(bm.as_device_array(candidates))
valid_ids = jax.numpy.where(res.success)[0]
self._fixed_points = np.asarray(res.x[valid_ids])
self._losses = np.asarray(res.fun[valid_ids])
self._selected_ids = np.asarray(valid_ids)
print(f"Optimizing with {opt_solver} to find fixed points:")

# optimizing
res = f_opt(tree_map(lambda a: bm.as_device_array(a),
candidates,
is_leaf=lambda a: isinstance(a, bm.JaxArray)))

# results
valid_ids = jnp.where(res.success)[0]
fixed_points = res.x[valid_ids]
if isinstance(candidates, dict):
indices = [0]
for v in candidates.values():
indices.append(v.shape[1])
indices = np.cumsum(indices)
keys = tuple(candidates.keys())
self._fixed_points = {key: fixed_points[:, indices[i]: indices[i + 1]]
for i, key in enumerate(keys)}
else:
self._fixed_points = fixed_points
self._losses = res.fun[valid_ids]
self._selected_ids = jnp.asarray(valid_ids)
if self.verbose:
print(f' '
f'Found {len(valid_ids)} fixed points from {len(candidates)} initial points.')
f'Found {len(valid_ids)} fixed points from {num_candidate} initial points.')

def filter_loss(self, tolerance=1e-5):
def filter_loss(self, tolerance: float = 1e-5):
"""Filter fixed points whose speed larger than a given tolerance.

Parameters
@@ -230,18 +461,21 @@ class SlowPointFinder(object):
if self.verbose:
print(f"Excluding fixed points with squared speed above "
f"tolerance {tolerance}:")
num_fps = self.fixed_points.shape[0]
if isinstance(self._fixed_points, dict):
num_fps = tuple(self._fixed_points.values())[0].shape[0]
else:
num_fps = self._fixed_points.shape[0]
ids = self._losses < tolerance
keep_ids = bm.where(ids)[0]
self._fixed_points = self._fixed_points[ids]
keep_ids = bm.as_device_array(bm.where(ids)[0])
self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
self._losses = self._losses[keep_ids]
self._selected_ids = self._selected_ids[keep_ids]
if self.verbose:
print(f" "
f"Kept {self._fixed_points.shape[0]}/{num_fps} "
f"Kept {len(keep_ids)}/{num_fps} "
f"fixed points with tolerance under {tolerance}.")

def keep_unique(self, tolerance=2.5e-2):
def keep_unique(self, tolerance: float = 2.5e-2):
"""Filter unique fixed points by choosing a representative within tolerance.

Parameters
@@ -251,16 +485,19 @@ class SlowPointFinder(object):
"""
if self.verbose:
print("Excluding non-unique fixed points:")
num_fps = self.fixed_points.shape[0]
if isinstance(self._fixed_points, dict):
num_fps = tuple(self._fixed_points.values())[0].shape[0]
else:
num_fps = self._fixed_points.shape[0]
fps, keep_ids = utils.keep_unique(self.fixed_points, tolerance=tolerance)
self._fixed_points = fps
self._fixed_points = tree_map(lambda a: jnp.asarray(a), fps)
self._losses = self._losses[keep_ids]
self._selected_ids = self._selected_ids[keep_ids]
if self.verbose:
print(f" Kept {self._fixed_points.shape[0]}/{num_fps} unique fixed points "
print(f" Kept {keep_ids.shape[0]}/{num_fps} unique fixed points "
f"with uniqueness tolerance {tolerance}.")

def exclude_outliers(self, tolerance=1e0):
def exclude_outliers(self, tolerance: float = 1e0):
"""Exclude points whose closest neighbor is further than threshold.

Parameters
@@ -272,11 +509,15 @@ class SlowPointFinder(object):
print("Excluding outliers:")
if np.isinf(tolerance):
return
if self._fixed_points.shape[0] <= 1:
if isinstance(self._fixed_points, dict):
num_fps = tuple(self._fixed_points.values())[0].shape[0]
else:
num_fps = self._fixed_points.shape[0]
if num_fps <= 1:
return

# Compute pairwise distances between all fixed points.
distances = utils.euclidean_distance(self._fixed_points)
distances = np.asarray(utils.euclidean_distance_jax(self.fixed_points, num_fps))

# Find second smallest element in each column of the pairwise distance matrix.
# This corresponds to the closest neighbor for each fixed point.
@@ -284,8 +525,7 @@ class SlowPointFinder(object):

# Return data with outliers removed and indices of kept datapoints.
keep_ids = np.where(closest_neighbor < tolerance)[0]
num_fps = self._fixed_points.shape[0]
self._fixed_points = self._fixed_points[keep_ids]
self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
self._selected_ids = self._selected_ids[keep_ids]
self._losses = self._losses[keep_ids]

@@ -294,32 +534,83 @@ class SlowPointFinder(object):
f"Kept {keep_ids.shape[0]}/{num_fps} fixed points "
f"with within outlier tolerance {tolerance}.")

def compute_jacobians(
self,
points: Union[Array, Dict[str, Array]],
stack_dict_var: bool = True,
plot: bool = False,
num_col: int = 4,
len_col: int = 3,
len_row: int = 2,
):
"""Compute the Jacobian matrices at the points.

Parameters
----------
points: np.ndarray, bm.JaxArray, jax.ndarray
The fixed points with the shape of (num_point, num_dim).

stack_dict_var: bool
Stack dictionary variables to calculate Jacobian matrix?
plot: bool
Plot the decomposition results of the Jacobian matrix.
num_col: int
The number of the figure column.
len_col: int
The length of each column.
len_row: int
The length of each row.
"""
# check data
info = np.asarray([(l.ndim, l.shape[0])
for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.JaxArray))[0]])
ndim = np.unique(info[:, 0])
if len(ndim) != 1: raise ValueError(f'Got multiple dimensions of the evaluated points: {ndim}')
if ndim[0] == 1:
points = tree_map(lambda a: bm.asarray([a]), points)
num_point = 1
elif ndim[0] == 2:
nsize = np.unique(info[:, 1])
if len(nsize) != 1: raise ValueError(f'The numbers of evaluated points do not match: {nsize}')
num_point = nsize[0]
else:
raise ValueError('Only support points of 1D: (num_feature,) or 2D: (num_point, num_feature)')
if isinstance(points, dict) and stack_dict_var:
points = bm.hstack(points.values()).value

# get Jacobian matrix
jacobian = self._get_f_jocabian(stack_dict_var)(points)

# visualization
if plot:
import matplotlib.pyplot as plt
from brainpy.visualization import visualize
jacobian = bm.as_numpy(jacobian)

num_col = min(num_col, num_point)
num_row = int(math.ceil(num_point / num_col))
fig, gs = visualize.get_figure(num_row, num_col, len_row, len_col)
for i in range(num_point):
eigval, eigvec = np.linalg.eig(np.asarray(jacobian[i]))
ax = fig.add_subplot(gs[i // num_col, i % num_col])
ax.scatter(np.real(eigval), np.imag(eigval))
ax.plot([1, 1] if self.f_type == constants.DISCRETE else [0, 0], [-1, 1], '--')
ax.set_xlabel('Real')
ax.set_ylabel('Imaginary')
ax.set_title(f'Point {i}')
plt.show()

return jacobian
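
# Editor's sketch (not part of this PR): Jacobian matrices at the remaining
# fixed points, continuing the ``_finder`` example above. One (dim, dim)
# matrix is returned per point; ``plot=True`` would additionally display the
# eigenvalue spectra drawn by the visualization branch above.
_jacobians = _finder.compute_jacobians(_finder.fixed_points, plot=False)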

@staticmethod
def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False):
"""Compute the eigenvalues of the matrices.

Parameters
----------
matrices: np.ndarray, bm.JaxArray, jax.ndarray
A 3D array with the shape of (num_matrices, dim, dim).
sort_by: str
The method of sorting.
do_compute_lefts: bool
Compute the left eigenvectors? Requires a pseudo-inverse call.

@@ -335,6 +626,7 @@ class SlowPointFinder(object):
sort_fun = np.real
else:
raise ValueError("Not implemented yet.")
matrices = np.asarray(matrices)

decompositions = []
for mat in matrices:
@@ -348,3 +640,193 @@ class SlowPointFinder(object):
'R': eig_vectors[:, indices],
'L': L})
return decompositions
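
# Editor's sketch (not part of this PR): spectral decomposition of the
# Jacobians computed above. Each entry carries the right ('R') and, when
# requested, left ('L') eigenvectors; for a continuous system a point is
# linearly stable when every eigenvalue has a negative real part (for a
# discrete map, magnitude below one), matching the dashed reference line
# plotted by ``compute_jacobians``.
_decompositions = _finder.decompose_eigenvalues(_jacobians)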

def _get_f_eval_loss(self, ):
name = 'f_eval_loss'
if name not in self._opt_functions:
self._opt_functions[name] = self._generate_f_eval_loss()
return self._opt_functions[name]

def _generate_f_eval_loss(self):
# evaluate losses of a batch of inputs
if self.f_type == constants.DISCRETE:
f_eval_loss = lambda h: self.f_loss(h, vmap(self.f_cell)(h), axis=1)
else:
f_eval_loss = lambda h: self.f_loss(vmap(self.f_cell)(h), axis=1)

if isinstance(self.target, DynamicalSystem):
@bm.jit
def loss_func(h):
r = f_eval_loss(h)
for k, v in self.excluded_vars.items():
v.value = self.excluded_data[k]
for k, v in self.target_vars.items():
v.value = self.target_data[k]
return r

return loss_func
else:
return bm.jit(f_eval_loss)

def _get_f_for_opt_solver(self, candidates, opt_method):
# loss function
if self.f_type == constants.DISCRETE:
# overall loss function for fixed points optimization
if isinstance(candidates, dict):
keys = tuple(self.target_vars.keys())
indices = [0]
for v in self.target_vars.values():
indices.append(v.shape[0])
indices = np.cumsum(indices)

def f_loss(h):
h = {key: h[indices[i]: indices[i + 1]] for i, key in enumerate(keys)}
return bm.as_device_array(self.f_loss(h, self.f_cell(h)))
else:
def f_loss(h):
return bm.as_device_array(self.f_loss(h, self.f_cell(h)))
else:
# overall loss function for fixed points optimization
def f_loss(h):
return self.f_loss(self.f_cell(h))

@bm.jit
@vmap
def f_opt(x0):
for k, v in self.target_vars.items():
v.value = x0[k] if v.batch_axis is None else bm.expand_dims(x0[k], axis=v.batch_axis)
for k, v in self.excluded_vars.items():
v.value = self.excluded_data[k]
if isinstance(x0, dict):
x0 = bm.concatenate(tuple(x0.values())).value
return opt_method(f_loss, x0)

def call_opt(x):
r = f_opt(x)
for k, v in self.excluded_vars.items():
v.value = self.excluded_data[k]
for k, v in self.target_vars.items():
v.value = self.target_data[k]
return r

return call_opt if isinstance(self.target, DynamicalSystem) else f_opt

def _generate_ds_cell_function(
self, target,
t: float = None,
dt: float = None,
f_input: Callable = None
):
if dt is None: dt = bm.get_dt()
if t is None: t = 0.
shared = DotDict(t=t, dt=dt, i=0)

def f_cell(h: Dict):
target.clear_input()

# update target variables
for k, v in self.target_vars.items():
v.value = (bm.asarray(h[k], dtype=v.dtype)
if v.batch_axis is None else
bm.asarray(bm.expand_dims(h[k], axis=v.batch_axis), dtype=v.dtype))

# update excluded variables
for k, v in self.excluded_vars.items():
v.value = self.excluded_data[k]

# add inputs
if f_input is not None:
f_input(shared)

# call update functions
args = (shared,) + self.args
target.update(*args)

# get new states
new_h = {k: (v.value if v.batch_axis is None else jnp.squeeze(v.value, axis=v.batch_axis))
for k, v in self.target_vars.items()}
return new_h

return f_cell

def _get_f_jocabian(self, stack=True):
name = f'f_eval_jacobian_stack={stack}'
if name not in self._opt_functions:
self._opt_functions[name] = self._generate_ds_jocabian(stack)
return self._opt_functions[name]

def _generate_ds_jocabian(self, stack=True):
if stack and isinstance(self.target, DynamicalSystem):
indices = [0]
for var in self.target_vars.values():
shape = list(var.shape)
if var.batch_axis is not None:
shape.pop(var.batch_axis)
indices.append(np.prod(shape))
indices = np.cumsum(indices)

def jacob(x0):
x0 = {k: x0[indices[i]:indices[i + 1]] for i, k in enumerate(self.target_vars.keys())}
r = self.f_cell(x0)
return bm.concatenate(list(r.values()))
else:
jacob = self.f_cell

f_jac = bm.jit(vmap(bm.jacobian(jacob)))

if isinstance(self.target, DynamicalSystem):
def jacobian_func(x):
r = f_jac(x)
for k, v in self.excluded_vars.items():
v.value = self.excluded_data[k]
for k, v in self.target_vars.items():
v.value = self.target_data[k]
return r

return jacobian_func
else:
return f_jac

def _check_candidates(self, candidates):
if isinstance(self.target, DynamicalSystem):
if not isinstance(candidates, dict):
raise ValueError(f'When "f_cell" is instance of {DynamicalSystem.__name__}, '
f'we should provide "candidates" as a dict, in which the key is '
f'the variable name with relative path, and the value '
f'is the candidate fixed point values. ')
for key in candidates:
if key not in self.target_vars:
raise KeyError(f'"{key}" is not defined in required variables '
f'for fixed point optimization of {self.target}. '
f'Please do not provide its initial values.')

for key in self.target_vars.keys():
if key not in candidates:
raise KeyError(f'"{key}" is defined in required variables '
f'for fixed point optimization of {self.target}. '
f'Please provide its initial values.')
for key, value in candidates.items():
if self.target_vars[key].batch_axis is None:
if value.ndim != self.target_vars[key].ndim + 1:
raise ValueError(f'"{key}" is defined in the required variables for fixed '
f'point optimization of {self.target}. \n'
f'We expect the provided candidate to have a batch size, '
f'but we got {value.shape} for variable with shape of '
f'{self.target_vars[key].shape}')
else:
if value.ndim != self.target_vars[key].ndim:
raise ValueError(f'"{key}" is defined in the required variables for fixed '
f'point optimization of {self.target}. \n'
f'We expect the provided candidate to have a batch size, '
f'but we got {value.shape} for variable with shape of '
f'{self.target_vars[key].shape}')

if isinstance(candidates, dict):
num_candidate = np.unique([leaf.shape[0] for leaf in candidates.values()])
if len(num_candidate) != 1:
raise ValueError('The numbers of candidates for each variable should be the same. '
f'But we got {num_candidate}')
num_candidate = num_candidate[0]
else:
num_candidate = candidates.shape[0]
return num_candidate

+ 165
- 0
brainpy/analysis/highdim/tests/test_slow_points.py

@@ -0,0 +1,165 @@
# -*- coding: utf-8 -*-

import brainpy as bp
import unittest
import brainpy.math as bm


class HH(bp.dyn.NeuGroup):
def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03,
V_th=20., C=1.0, name=None):
super(HH, self).__init__(size=size, name=name)

# parameters
self.ENa = ENa
self.EK = EK
self.EL = EL
self.C = C
self.gNa = gNa
self.gK = gK
self.gL = gL
self.V_th = V_th

# variables
self.V = bm.Variable(bm.ones(self.num) * -65.)
self.m = bm.Variable(0.5 * bm.ones(self.num))
self.h = bm.Variable(0.6 * bm.ones(self.num))
self.n = bm.Variable(0.32 * bm.ones(self.num))
self.spike = bm.Variable(bm.zeros(size, dtype=bool))
self.input = bm.Variable(bm.zeros(size))

# integral functions
self.int_h = bp.ode.ExponentialEuler(self.dh)
self.int_n = bp.ode.ExponentialEuler(self.dn)
self.int_m = bp.ode.ExponentialEuler(self.dm)
self.int_V = bp.ode.ExponentialEuler(self.dV)

def dh(self, h, t, V):
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
dhdt = alpha * (1 - h) - beta * h
return dhdt

def dn(self, n, t, V):
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
beta = 0.125 * bm.exp(-(V + 65) / 80)
dndt = alpha * (1 - n) - beta * n
return dndt

def dm(self, m, t, V):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
beta = 4.0 * bm.exp(-(V + 65) / 18)
dmdt = alpha * (1 - m) - beta * m
return dmdt

def dV(self, V, t, m, h, n, Iext):
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
I_K = (self.gK * n ** 4.0) * (V - self.EK)
I_leak = self.gL * (V - self.EL)
dVdt = (- I_Na - I_K - I_leak + Iext) / self.C
return dVdt

def update(self, tdi):
t, dt = tdi.t, tdi.dt
m = self.int_m(self.m, t, self.V, dt=dt)
h = self.int_h(self.h, t, self.V, dt=dt)
n = self.int_n(self.n, t, self.V, dt=dt)
V = self.int_V(self.V, t, self.m, self.h, self.n, self.input, dt=dt)
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.h.value = h
self.n.value = n
self.m.value = m
self.input[:] = 0.


class TestFixedPointsFinding(unittest.TestCase):
def test_opt_solver_for_func1(self):
gamma = 0.641 # Saturation factor for gating variable
tau = 0.06 # Synaptic time constant [sec]
a = 270.
b = 108.
d = 0.154

JE = 0.3725 # self-coupling strength [nA]
JI = -0.1137 # cross-coupling strength [nA]
JAext = 0.00117 # Stimulus input strength [nA]

mu = 20. # Stimulus firing rate [spikes/sec]
coh = 0.5 # Stimulus coherence [%]
Ib1 = 0.3297
Ib2 = 0.3297

def ds1(s1, t, s2, coh=0.5, mu=20.):
I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh)
r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
return - s1 / tau + (1. - s1) * gamma * r1

def ds2(s2, t, s1, coh=0.5, mu=20.):
I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh)
r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
return - s2 / tau + (1. - s2) * gamma * r2

def step(s):
return bm.asarray([ds1(s[0], 0., s[1]), ds2(s[1], 0., s[0])])

finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS)
finder.find_fps_with_opt_solver(bm.random.random((100, 2)))

def test_opt_solver_for_ds1(self):
hh = HH(1)
finder = bp.analysis.SlowPointFinder(f_cell=hh, excluded_vars=[hh.input, hh.spike])

with self.assertRaises(ValueError):
finder.find_fps_with_opt_solver(bm.random.random((100, 4)))

finder.find_fps_with_opt_solver({'V': bm.random.random((100, 1)),
'm': bm.random.random((100, 1)),
'h': bm.random.random((100, 1)),
'n': bm.random.random((100, 1))})

def test_gd_method_for_func1(self):
gamma = 0.641 # Saturation factor for gating variable
tau = 0.06 # Synaptic time constant [sec]
a = 270.
b = 108.
d = 0.154

JE = 0.3725 # self-coupling strength [nA]
JI = -0.1137 # cross-coupling strength [nA]
JAext = 0.00117 # Stimulus input strength [nA]

mu = 20. # Stimulus firing rate [spikes/sec]
coh = 0.5 # Stimulus coherence [%]
Ib1 = 0.3297
Ib2 = 0.3297

def ds1(s1, t, s2, coh=0.5, mu=20.):
I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh)
r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
return - s1 / tau + (1. - s1) * gamma * r1

def ds2(s2, t, s1, coh=0.5, mu=20.):
I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh)
r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
return - s2 / tau + (1. - s2) * gamma * r2

def step(s):
return bm.asarray([ds1(s[0], 0., s[1]), ds2(s[1], 0., s[0])])

finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS)
finder.find_fps_with_gd_method(bm.random.random((100, 2)), num_opt=100)

def test_gd_method_for_func2(self):
hh = HH(1)
finder = bp.analysis.SlowPointFinder(f_cell=hh, excluded_vars=[hh.input, hh.spike])

with self.assertRaises(ValueError):
finder.find_fps_with_opt_solver(bm.random.random((100, 4)))

finder.find_fps_with_gd_method({'V': bm.random.random((100, 1)),
'm': bm.random.random((100, 1)),
'h': bm.random.random((100, 1)),
'n': bm.random.random((100, 1))},
num_opt=100)


+ 34
- 24
brainpy/analysis/lowdim/lowdim_analyzer.py

@@ -1,15 +1,17 @@
# -*- coding: utf-8 -*-

import warnings
from functools import partial

import numpy as np
from jax import vmap
from jax import numpy as jnp
from jax import vmap
from jax.scipy.optimize import minimize

import brainpy.math as bm
from brainpy import errors, tools
from brainpy.analysis import constants as C, utils
from brainpy.analysis.base import DSAnalyzer
from brainpy.base.collector import Collector

pyplot = None
@@ -21,7 +23,7 @@ __all__ = [
]


class LowDimAnalyzer(object):
class LowDimAnalyzer(DSAnalyzer):
r"""Automatic Analyzer for Low-dimensional Dynamical Systems.

A dynamical model is characterized by a series of dynamical
@@ -68,16 +70,18 @@ class LowDimAnalyzer(object):
The optional setting. Maybe needed in the individual analyzer.
"""

def __init__(self,
model,
target_vars,
fixed_vars=None,
target_pars=None,
pars_update=None,
resolutions=None,
jit_device=None,
lim_scale=1.05,
options=None, ):
def __init__(
self,
model,
target_vars,
fixed_vars=None,
target_pars=None,
pars_update=None,
resolutions=None,
jit_device=None,
lim_scale=1.05,
options=None,
):
# model
# -----
self.model = utils.model_transform(model)
@@ -152,6 +156,12 @@ class LowDimAnalyzer(object):
for key, lim in self.target_pars.items():
self.resolutions[key] = bm.linspace(*lim, 20)
elif isinstance(resolutions, float):
if len(self.target_pars) >= 1:
warnings.warn('The same `resolutions` setting will be applied to all target parameters and variables. '
'The analysis may occupy too much memory if `resolutions` is small. '
'Please specify `resolutions` for each parameter and variable with a dict, '
'e.g., resolutions={"V": 0.1}.',
category=UserWarning)
for key, lim in self.target_vars.items():
self.resolutions[key] = bm.arange(*lim, resolutions)
for key, lim in self.target_pars.items():
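
# Editor's sketch (not part of this PR): the two accepted forms of
# ``resolutions``. A single float applies the same grid step to every target
# (and now triggers the warning above when target parameters are present); a
# dict sets the step per variable/parameter. Names below are illustrative only.
_resolutions_as_float = 0.01
_resolutions_as_dict = {'V': 0.1, 'Iext': 0.05}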
@@ -163,7 +173,7 @@ class LowDimAnalyzer(object):
if key in self.target_par_names:
continue
raise errors.AnalyzerError(f'The resolution setting target "{key}" is not found in '
f'the target variables {self.target_var_names} and '
f'the target variables {self.target_var_names} or '
f'the target parameters {self.target_par_names}.')
for key in self.target_var_names + self.target_par_names:
if key not in resolutions:
@@ -206,7 +216,7 @@ class LowDimAnalyzer(object):
# 'x_by_y_in_fy' :
# 'y_by_x_in_fx' :
# 'x_by_y_in_fx' :
self.analyzed_results = tools.DictPlus()
self.analyzed_results = tools.DotDict()

def show_figure(self):
global pyplot
@@ -251,9 +261,9 @@ class Num1DAnalyzer(LowDimAnalyzer):
>>> self.F_fx(v1, v2, p1, p2)
"""
if C.F_fx not in self.analyzed_results:
_, arguments = utils.get_args(self.model.F[self.x_var])
_, arguments = utils.get_args(self.model.f_derivatives[self.x_var])
wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names)
f = wrapper(self.model.F[self.x_var])
f = wrapper(self.model.f_derivatives[self.x_var])
f = partial(f, **(self.pars_update + self.fixed_vars))
f = utils.f_without_jaxarray_return(f)
f = utils.remove_return_shape(f)
@@ -412,9 +422,9 @@ class Num2DAnalyzer(Num1DAnalyzer):
>>> self.F_fy(v1, v2, p1, p2)
"""
if C.F_fy not in self.analyzed_results:
variables, arguments = utils.get_args(self.model.F[self.y_var])
variables, arguments = utils.get_args(self.model.f_derivatives[self.y_var])
wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names)
f = wrapper(self.model.F[self.y_var])
f = wrapper(self.model.f_derivatives[self.y_var])
f = partial(f, **(self.pars_update + self.fixed_vars))
f = utils.f_without_jaxarray_return(f)
f = utils.remove_return_shape(f)
@@ -424,18 +434,18 @@ class Num2DAnalyzer(Num1DAnalyzer):
@property
def F_int_x(self):
if C.F_int_x not in self.analyzed_results:
wrap_x = utils.std_derivative(utils.get_args(self.model.F[self.x_var])[1],
wrap_x = utils.std_derivative(utils.get_args(self.model.f_derivatives[self.x_var])[1],
self.target_var_names, self.target_par_names)
init_x = partial(wrap_x(self.model.INTG[0]), **(self.pars_update + self.fixed_vars))
init_x = partial(wrap_x(self.model.f_integrals[0]), **(self.pars_update + self.fixed_vars))
self.analyzed_results[C.F_int_x] = init_x
return self.analyzed_results[C.F_int_x]

@property
def F_int_y(self):
if C.F_int_y not in self.analyzed_results:
wrap_x = utils.std_derivative(utils.get_args(self.model.F[self.y_var])[1],
wrap_x = utils.std_derivative(utils.get_args(self.model.f_derivatives[self.y_var])[1],
self.target_var_names, self.target_par_names)
init_x = partial(wrap_x(self.model.INTG[1]), **(self.pars_update + self.fixed_vars))
init_x = partial(wrap_x(self.model.f_integrals[1]), **(self.pars_update + self.fixed_vars))
self.analyzed_results[C.F_int_y] = init_x
return self.analyzed_results[C.F_int_y]

@@ -1021,9 +1031,9 @@ class Num3DAnalyzer(Num2DAnalyzer):
def F_fz(self):
"""The function to evaluate :math:`f_y(*\mathrm{vars}, *\mathrm{pars})`."""
if C.F_fz not in self.analyzed_results:
variables, arguments = utils.get_args(self.model.F[self.z_var])
variables, arguments = utils.get_args(self.model.f_derivatives[self.z_var])
wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names)
f = wrapper(self.model.F[self.z_var])
f = wrapper(self.model.f_derivatives[self.z_var])
f = partial(f, **(self.pars_update + self.fixed_vars))
self.analyzed_results[C.F_fz] = bm.jit(f, device=self.jit_device)
return self.analyzed_results[C.F_fz]


+ 2
- 0
brainpy/analysis/lowdim/tests/test_phase_plane.py

@@ -26,6 +26,7 @@ class TestPhasePlane(unittest.TestCase):
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
plt.show(block=block)
plt.close()
bp.math.disable_x64()

def test_2d_decision_making_model(self):
@@ -74,4 +75,5 @@ class TestPhasePlane(unittest.TestCase):
analyzer.plot_nullcline(coords=dict(s2='s2-s1'))
analyzer.plot_fixed_point()
plt.show(block=block)
plt.close()
bp.math.disable_x64()

+ 67
- 10
brainpy/analysis/utils/measurement.py

@@ -1,13 +1,20 @@
# -*- coding: utf-8 -*-

from functools import partial
from typing import Union

import jax
import jax.numpy as jnp
import numpy as np
from brainpy.tools.others import numba_jit
from jax.tree_util import tree_flatten

import brainpy.math as bm
from brainpy.tools.others import numba_jit

__all__ = [
'find_indexes_of_limit_cycle_max',
'euclidean_distance',
'euclidean_distance_jax',
]


@@ -31,8 +38,8 @@ def find_indexes_of_limit_cycle_max(arr, tol=0.001):
return _f1(arr, grad, tol)


# @tools.numba_jit
def euclidean_distance(points: np.ndarray):
@numba_jit
def euclidean_distance(points: np.ndarray, num_point=None):
"""Get the distance matrix.

Equivalent to:
@@ -50,13 +57,63 @@ def euclidean_distance(points: np.ndarray):
dist_matrix: jnp.ndarray
The distance matrix.
"""
num_point = points.shape[0]
indices = np.triu_indices(num_point)
dist_mat = np.zeros((num_point, num_point))
for idx in range(len(indices[0])):
i = indices[0][idx]
j = indices[1][idx]
dist_mat[i, j] = np.linalg.norm(points[i] - points[j])

if isinstance(points, dict):
if num_point is None:
raise ValueError('Please provide num_point')
indices = np.triu_indices(num_point)
dist_mat = np.zeros((num_point, num_point))
for idx in range(len(indices[0])):
i = indices[0][idx]
j = indices[1][idx]
dist_mat[i, j] = np.sqrt(np.sum([np.sum((value[i] - value[j]) ** 2) for value in points.values()]))
else:
num_point = points.shape[0]
indices = np.triu_indices(num_point)
dist_mat = np.zeros((num_point, num_point))
for idx in range(len(indices[0])):
i = indices[0][idx]
j = indices[1][idx]
dist_mat[i, j] = np.linalg.norm(points[i] - points[j])
dist_mat = np.maximum(dist_mat, dist_mat.T)
return dist_mat


@jax.jit
@partial(jax.vmap, in_axes=[0, 0, None])
def _ed(i, j, leaves):
squares = bm.asarray([((leaf[i] - leaf[j]) ** 2).sum() for leaf in leaves])
return bm.sqrt(bm.sum(squares))


def euclidean_distance_jax(points: Union[jnp.ndarray, bm.ndarray], num_point=None):
"""Get the distance matrix.

Equivalent to:

>>> from scipy.spatial.distance import squareform, pdist
>>> f = lambda points: squareform(pdist(points, metric="euclidean"))

Parameters
----------
points: jnp.ndarray, bm.JaxArray
The points.
num_point: int
  The number of points (required when ``points`` is a dict).

Returns
-------
dist_matrix: JaxArray
The distance matrix.
"""
if isinstance(points, dict):
if num_point is None:
raise ValueError('Please provide num_point')
else:
num_point = points.shape[0]
indices = jnp.triu_indices(num_point)
dist_mat = bm.zeros((num_point, num_point))
leaves, _ = tree_flatten(points)
dist_mat[indices] = _ed(*indices, leaves)
dist_mat = bm.maximum(dist_mat, dist_mat.T)
return dist_mat
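
# Editor's sketch (not part of this PR): a quick equivalence check against the
# SciPy form quoted in the docstring above (assumes SciPy is installed).
from scipy.spatial.distance import pdist, squareform

_pts = np.random.rand(10, 3).astype(np.float32)
_d_jax = np.asarray(euclidean_distance_jax(jnp.asarray(_pts)))
_d_ref = squareform(pdist(_pts, metric='euclidean'))
assert np.allclose(_d_jax, _d_ref, atol=1e-5)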


+ 51
- 41
brainpy/analysis/utils/model.py

@@ -4,11 +4,12 @@
import jax.numpy as jnp

import brainpy.math as bm
from brainpy import errors
from brainpy.dyn.base import DynamicalSystem
from brainpy.dyn.runners import DSRunner
from brainpy.errors import AnalyzerError, UnsupportedError
from brainpy.integrators.base import Integrator
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode.base import ODEIntegrator
from brainpy.integrators.ode import ODEIntegrator, odeint

__all__ = [
'model_transform',
@@ -17,63 +18,69 @@ __all__ = [
]


def _check_model(model):
if isinstance(model, Integrator):
if not isinstance(model, ODEIntegrator):
raise AnalyzerError(f'Must be the instance of {ODEIntegrator.__name__}, but got {model}.')
elif callable(model):
model = odeint(model)
else:
raise ValueError(f'Please provide derivative function or integral function. But we got {model}')
if isinstance(model.f, JointEq):
return [type(model)(eq, var_type=model.var_type, dt=model.dt) for eq in model.f.eqs]
else:
return [model]


def model_transform(model):
# check integrals
if isinstance(model, NumDSWrapper):
# check model
if isinstance(model, DynamicalSystem):
model = tuple(model.nodes(level=-1).subset(ODEIntegrator).unique().values())
elif isinstance(model, NumDSWrapper):
return model
elif isinstance(model, ODEIntegrator): #
model = [model]

# check model types
elif callable(model):
model = [model]
all_models = []
if isinstance(model, (list, tuple)):
if len(model) == 0:
raise AnalyzerError(f'Found no derivative/integral functions: {model}')
for fun in tuple(model):
all_models.extend(_check_model(fun))
elif isinstance(model, dict):
if len(model) == 0:
raise AnalyzerError(f'Found no derivative/integral functions: {model}')
for fun in tuple(model.values()):
all_models.extend(_check_model(fun))
else:
raise UnsupportedError(f'Dynamics analysis by symbolic approach only supports '
f'derivative/integral functions or {DynamicalSystem.__name__}, '
f'but we got: {type(model)}: {str(model)}')

# pars to update
pars_update = set()
for intg in new_model:
pars_update.update(intg.parameters[1:])
for fun in all_models:
pars_update.update(fun.parameters[1:])

# variables and parameters
all_variables = set()
all_parameters = set()
for integral in new_model:
for integral in all_models:
# variable
if len(integral.variables) != 1:
raise errors.AnalyzerError(f'Only supports one {ODEIntegrator.__name__} one variable, '
f'but we got {len(integral.variables)} variables in {integral}.')
raise AnalyzerError(f'Only supports one variable per {ODEIntegrator.__name__}, '
                    f'but we got {len(integral.variables)} variables in {integral}.')
var = integral.variables[0]
if var in all_variables:
raise errors.AnalyzerError(f'Variable name {var} has been defined before. '
f'Please change another name.')
raise AnalyzerError(f'Variable name {var} has been defined before. '
                    f'Please choose another name.')
all_variables.add(var)
# parameters
# parameter
all_parameters.update(integral.parameters[1:])

# form a dynamic model
return NumDSWrapper(integrals=new_model,
return NumDSWrapper(integrals=all_models,
variables=list(all_variables),
parameters=list(all_parameters),
pars_update=pars_update)
@@ -87,14 +94,17 @@ class NumDSWrapper(object):
variables,
parameters,
pars_update=None):
self.INTG = integrals # all integrators
self.F = {intg.variables[0]: intg.f for intg in integrals} # all integrators
self.f_integrals = integrals # all integrators
self.f_derivatives = {intg.variables[0]: intg.f for intg in integrals} # all integrators
self.variables = variables # all variables
self.parameters = parameters # all parameters
self.pars_update = pars_update # the parameters to update
self.name2integral = {intg.variables[0]: intg for intg in integrals}
self.name2derivative = {intg.variables[0]: intg.f for intg in integrals}

def __repr__(self):
return f'{self.__class__.__name__}(variables={self.variables}, parameters={self.parameters})'


class TrajectModel(DynamicalSystem):
def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
@@ -121,10 +131,10 @@ class TrajectModel(DynamicalSystem):
dyn_vars=self.vars().unique(), dt=dt,
progress_bar=False)

def update(self, t, dt):
def update(self, sha):
all_vars = list(self.implicit_vars.values())
for key, intg in self.integrals.items():
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=dt))
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=sha['dt']))

def __getattr__(self, item):
child_vars = super(TrajectModel, self).__getattribute__('implicit_vars')


+ 55
- 6
brainpy/analysis/utils/others.py

@@ -1,12 +1,14 @@
# -*- coding: utf-8 -*-

from typing import Union, Dict
import jax.numpy as jnp
from jax import vmap
import numpy as np
from jax.tree_util import tree_flatten, tree_map

import brainpy.math as bm
from .function import f_without_jaxarray_return
from .measurement import euclidean_distance
from .measurement import euclidean_distance, euclidean_distance_jax

__all__ = [
'Segment',
@@ -44,7 +46,7 @@ def check_initials(initials, target_var_names):
assert isinstance(initials, dict)
for p in target_var_names:
assert p in initials
initials = {p: bm.asarray(initials[p], dtype=bm.float_) for p in target_var_names}
initials = {p: bm.asarray(initials[p], dtype=bm.dftype()) for p in target_var_names}
len_of_init = []
for v in initials.values():
assert isinstance(v, (tuple, list, np.ndarray, jnp.ndarray, bm.ndarray))
@@ -85,12 +87,59 @@ def get_sign2(f, *xyz, args=()):
return jnp.sign(f(*(XYZ + args))).reshape(shape)


def keep_unique(candidates, tolerance=2.5e-2):
def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]],
tolerance: float=2.5e-2):
"""Filter unique fixed points by choosing a representative within tolerance.

Parameters
----------
candidates: np.ndarray
candidates: np.ndarray, dict
The fixed points with the shape of (num_point, num_dim).
tolerance: float
The distance tolerance under which two points are regarded as the same fixed point.

Returns
-------
fps_and_ids : tuple
A 2-tuple of (kept fixed points, ids of kept fixed points).
"""
if isinstance(candidates, dict):
element = tuple(candidates.values())[0]
num_fps = element.shape[0]
dtype = element.dtype
else:
num_fps = candidates.shape[0]
dtype = candidates.dtype
keep_ids = np.arange(num_fps)
if tolerance <= 0.0:
return candidates, keep_ids
if num_fps <= 1:
return candidates, keep_ids
candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.JaxArray))

# If point A and point B are within identical_tol of each other, and the
# A is first in the list, we keep A.
distances = np.asarray(euclidean_distance_jax(candidates, num_fps))
example_idxs = np.arange(num_fps)
all_drop_idxs = []
for fidx in range(num_fps - 1):
distances_f = distances[fidx, fidx + 1:]
drop_idxs = example_idxs[fidx + 1:][distances_f <= tolerance]
all_drop_idxs += list(drop_idxs)
keep_ids = np.setdiff1d(example_idxs, np.unique(all_drop_idxs))
if keep_ids.shape[0] > 0:
unique_fps = tree_map(lambda a: a[keep_ids], candidates)
else:
unique_fps = np.array([], dtype=dtype)
return unique_fps, keep_ids
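
# Editor's sketch (not part of this PR): three nearly-identical points collapse
# onto a single representative; the first member of each cluster is kept.
_pts = np.vstack([np.zeros((3, 2)), np.ones((2, 2))]) + 1e-4 * np.random.rand(5, 2)
_unique, _kept = keep_unique(_pts, tolerance=2.5e-2)
# expected: _unique.shape == (2, 2) and list(_kept) == [0, 3]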


def keep_unique_jax(candidates, tolerance=2.5e-2):
"""Filter unique fixed points by choosing a representative within tolerance.

Parameters
----------
candidates: Tensor
The fixed points with the shape of (num_point, num_dim).

Returns
@@ -107,14 +156,14 @@ def keep_unique(candidates, tolerance=2.5e-2):
# If point A and point B are within identical_tol of each other, and the
# A is first in the list, we keep A.
nfps = candidates.shape[0]
distances = euclidean_distance(candidates)
distances = euclidean_distance_jax(candidates)
example_idxs = np.arange(nfps)
all_drop_idxs = []
for fidx in range(nfps - 1):
distances_f = distances[fidx, fidx + 1:]
drop_idxs = example_idxs[fidx + 1:][distances_f <= tolerance]
all_drop_idxs += list(drop_idxs)
keep_ids = np.setdiff1d(example_idxs, np.unique(all_drop_idxs))
keep_ids = np.setdiff1d(example_idxs, np.unique(np.asarray(all_drop_idxs)))
if keep_ids.shape[0] > 0:
unique_fps = candidates[keep_ids, :]
else:


+ 79
- 30
brainpy/base/base.py

@@ -22,15 +22,15 @@ class Base(object):
The subclass of Base includes:

- ``DynamicalSystem`` in *brainpy.dyn.base.py*
- ``Module`` in *brainpy.dyn.base_module.py*
- ``Integrator`` in *brainpy.integrators.base.py*
- ``Function`` in *brainpy.base.function.py*
- ``AutoGrad`` in *brainpy.math.autograd.py*
- ``Optimizer`` in *brainpy.optimizers.py*
- ``Scheduler`` in *brainpy.optimizers.py*

"""

_excluded_vars = ()

def __init__(self, name=None):
# check whether the object has a unique name.
self._name = None
@@ -47,6 +47,7 @@ class Base(object):

@property
def name(self):
"""Name of the model."""
return self._name

@name.setter
@@ -54,13 +55,48 @@ class Base(object):
self._name = self.unique_name(name=name)
naming.check_name_uniqueness(name=self._name, obj=self)

def register_implicit_vars(self, variables):
assert isinstance(variables, dict), f'Must be a dict, but we got {type(variables)}'
self.implicit_vars.update(variables)

def register_implicit_nodes(self, nodes):
assert isinstance(nodes, dict), f'Must be a dict, but we got {type(nodes)}'
self.implicit_nodes.update(nodes)
def register_implicit_vars(self, *variables, **named_variables):
from brainpy.math import Variable
for variable in variables:
if isinstance(variable, Variable):
self.implicit_vars[f'var{id(variable)}'] = variable
elif isinstance(variable, (tuple, list)):
for v in variable:
if not isinstance(v, Variable):
raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(v)}')
self.implicit_vars[f'var{id(variable)}'] = v
elif isinstance(variable, dict):
for k, v in variable.items():
if not isinstance(v, Variable):
raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(v)}')
self.implicit_vars[k] = v
else:
raise ValueError(f'Unknown type: {type(variable)}')
for key, variable in named_variables.items():
if not isinstance(variable, Variable):
raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(variable)}')
self.implicit_vars[key] = variable

def register_implicit_nodes(self, *nodes, **named_nodes):
for node in nodes:
if isinstance(node, Base):
self.implicit_nodes[node.name] = node
elif isinstance(node, (tuple, list)):
for n in node:
if not isinstance(n, Base):
raise ValueError(f'Must be instance of {Base.__name__}, but we got {type(n)}')
self.implicit_nodes[n.name] = n
elif isinstance(node, dict):
for k, n in node.items():
if not isinstance(n, Base):
raise ValueError(f'Must be instance of {Base.__name__}, but we got {type(n)}')
self.implicit_nodes[k] = n
else:
raise ValueError(f'Unknown type: {type(node)}')
for key, node in named_nodes.items():
if not isinstance(node, Base):
raise ValueError(f'Must be instance of {Base.__name__}, but we got {type(node)}')
self.implicit_nodes[key] = node
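
# Editor's sketch (not part of this PR): the widened registration API accepts
# positional values, sequences, dicts, and keyword arguments. A minimal,
# self-contained illustration (class and variable names are assumptions):
import brainpy.math as bm
from brainpy.base.base import Base

class _RegDemo(Base):
  def __init__(self):
    super(_RegDemo, self).__init__()
    # positional Variables get auto-generated keys; dict/keyword entries keep theirs
    self.register_implicit_vars(bm.Variable(bm.zeros(1)),
                                {'w': bm.Variable(bm.ones(2))},
                                b=bm.Variable(bm.zeros(3)))

print(list(_RegDemo().implicit_vars.keys()))  # -> something like ['var140...', 'w', 'b']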

def vars(self, method='absolute', level=-1, include_self=True):
"""Collect all variables in this node and the children nodes.
@@ -88,7 +124,8 @@ class Base(object):
for k in dir(node):
v = getattr(node, k)
if isinstance(v, math.Variable):
gather[f'{node_path}.{k}' if node_path else k] = v
if k not in node._excluded_vars:
gather[f'{node_path}.{k}' if node_path else k] = v
gather.update({f'{node_path}.{k}': v for k, v in node.implicit_vars.items()})
return gather

@@ -117,6 +154,13 @@ class Base(object):
if _paths is None:
_paths = set()
gather = Collector()
if include_self:
if method == 'absolute':
gather[self.name] = self
elif method == 'relative':
gather[''] = self
else:
raise ValueError(f'No support for the method of "{method}".')
if (level > -1) and (_lid >= level):
return gather
if method == 'absolute':
@@ -135,13 +179,14 @@ class Base(object):
gather[node.name] = node
nodes.append(node)
for v in nodes:
gather.update(v._find_nodes(method=method, level=level, _lid=_lid + 1, _paths=_paths,
gather.update(v._find_nodes(method=method,
level=level,
_lid=_lid + 1,
_paths=_paths,
include_self=include_self))
if include_self: gather[self.name] = self

elif method == 'relative':
nodes = []
if include_self: gather[''] = self
for k, v in self.__dict__.items():
if isinstance(v, Base):
path = (id(self), id(v))
@@ -156,8 +201,11 @@ class Base(object):
gather[key] = node
nodes.append((key, node))
for k1, v1 in nodes:
for k2, v2 in v1._find_nodes(method=method, _paths=_paths, _lid=_lid + 1,
level=level, include_self=include_self).items():
for k2, v2 in v1._find_nodes(method=method,
_paths=_paths,
_lid=_lid + 1,
level=level,
include_self=include_self).items():
if k2: gather[f'{k1}.{k2}'] = v2

else:
@@ -208,7 +256,7 @@ class Base(object):
naming.check_name_uniqueness(name=name, obj=self)
return name

def load_states(self, filename, verbose=False, check_missing=False):
def load_states(self, filename, verbose=False):
"""Load the model states.

Parameters
@@ -216,41 +264,42 @@ class Base(object):
filename : str
The filename which stores the model states.
verbose: bool
check_missing: bool
Whether report the load progress.
"""
if not os.path.exists(filename):
raise errors.BrainPyError(f'Cannot find the file path: {filename}')
elif filename.endswith('.hdf5') or filename.endswith('.h5'):
io.load_h5(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_h5(filename, target=self, verbose=verbose)
elif filename.endswith('.pkl'):
io.load_pkl(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_pkl(filename, target=self, verbose=verbose)
elif filename.endswith('.npz'):
io.load_npz(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_npz(filename, target=self, verbose=verbose)
elif filename.endswith('.mat'):
io.load_mat(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_mat(filename, target=self, verbose=verbose)
else:
raise errors.BrainPyError(f'Unknown file format: {filename}. We only support {io.SUPPORTED_FORMATS}')

def save_states(self, filename, all_vars=None, **setting):
def save_states(self, filename, variables=None, **setting):
"""Save the model states.

Parameters
----------
filename : str
The file name which to store the model states.
all_vars: optional, dict, TensorCollector
variables: optional, dict, TensorCollector
The variables to save. If not provided, all variables retrieved by ``~.vars()`` will be used.
"""
if all_vars is None:
all_vars = self.vars(method='relative').unique()
if variables is None:
variables = self.vars(method='absolute', level=-1)

if filename.endswith('.hdf5') or filename.endswith('.h5'):
io.save_h5(filename, all_vars=all_vars)
elif filename.endswith('.pkl'):
io.save_pkl(filename, all_vars=all_vars)
io.save_as_h5(filename, variables=variables)
elif filename.endswith('.pkl') or filename.endswith('.pickle'):
io.save_as_pkl(filename, variables=variables)
elif filename.endswith('.npz'):
io.save_npz(filename, all_vars=all_vars, **setting)
io.save_as_npz(filename, variables=variables, **setting)
elif filename.endswith('.mat'):
io.save_mat(filename, all_vars=all_vars)
io.save_as_mat(filename, variables=variables)
else:
raise errors.BrainPyError(f'Unknown file format: {filename}. We only support {io.SUPPORTED_FORMATS}')
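
# Editor's sketch (not part of this PR): round-tripping model states through
# the renamed io entry points; ``.npz`` avoids the optional h5py/scipy
# dependencies. Class and file names are assumptions for illustration.
import brainpy.math as bm
from brainpy.base.base import Base

class _StateDemo(Base):
  def __init__(self):
    super(_StateDemo, self).__init__()
    self.w = bm.Variable(bm.zeros(3))

_m = _StateDemo()
_m.save_states('./_state_demo.npz')
_m.load_states('./_state_demo.npz', verbose=True)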



+ 56
- 51
brainpy/base/collector.py

@@ -1,9 +1,7 @@
# -*- coding: utf-8 -*-


import jax
import jax.numpy as jnp
from contextlib import contextmanager
from typing import Dict, Sequence, Union

math = None

@@ -29,31 +27,73 @@ class Collector(dict):
"""Replace the original key with the new value."""
self.pop(key)
self[key] = new_value
# dict.__setitem__(self, key, new_value)

def update(self, other, **kwargs):
assert isinstance(other, dict)
for key, value in other.items():
self[key] = value
assert isinstance(other, (dict, list, tuple))
if isinstance(other, dict):
for key, value in other.items():
self[key] = value
elif isinstance(other, (tuple, list)):
num = len(self)
for i, value in enumerate(other):
self[f'_var{i+num}'] = value
else:
raise ValueError(f'Only supports dict/list/tuple, but we got {type(other)}')
for key, value in kwargs.items():
self[key] = value

def __add__(self, other):
"""Merging two dicts.

Parameters
----------
other: dict
The other dict instance.

Returns
-------
gather: Collector
The new collector.
"""
gather = type(self)(self)
gather.update(other)
return gather

def __sub__(self, other):
def __sub__(self, other: Union[Dict, Sequence]):
"""Remove other item in the collector.

Parameters
----------
other: dict, sequence
The items to remove.

Returns
-------
gather: Collector
The new collector.
"""
if not isinstance(other, dict):
raise ValueError(f'Only support dict, but we got {type(other)}.')
gather = type(self)()
for key, val in self.values():
if key in other:
if id(val) != id(other[key]):
raise ValueError(f'Cannot remove {key}, because we got two different values: '
f'{val} != {other[key]}')
else:
gather[key] = val
gather = type(self)(self)
if isinstance(other, dict):
for key, val in other.items():
if key in gather:
if id(val) != id(gather[key]):
raise ValueError(f'Cannot remove {key}, because we got two different values: '
f'{val} != {gather[key]}')
gather.pop(key)
else:
raise ValueError(f'Cannot remove {key}, because we do not find it '
f'in {self.keys()}.')
elif isinstance(other, (list, tuple)):
for key in other:
if key in gather:
gather.pop(key)
else:
raise ValueError(f'Cannot remove {key}, because we do not find it '
f'in {self.keys()}.')
else:
raise KeyError(f'Unknown type of "other". Only support dict/tuple/list, but we got {type(other)}')
return gather
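
# Editor's sketch (not part of this PR): from client code, collectors merge
# with ``+`` and remove entries by key with ``-``, which now also accepts a
# list/tuple of keys.
_c = Collector(a=1, b=2, c=3)
print((_c + {'d': 4}).keys())    # -> a, b, c, d
print((_c - ['b', 'c']).keys())  # -> a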

def subset(self, var_type):
@@ -146,38 +186,3 @@ class TensorCollector(Collector):
def data(self):
"""Get all data in each value."""
return [x.value for x in self.values()]

# @contextmanager
# def replicate(self):
# """A context manager to use in a with statement that replicates
# the variables in this collection to multiple devices.
#
# Important: replicating also updates the random state in order
# to have a new one per device.
# """
# global math
# if math is None: from brainpy import math
#
# replicated, saved_states = {}, {}
# x = jnp.zeros((jax.local_device_count(), 1), dtype=math.float_)
# sharded_x = jax.pmap(lambda x: x, axis_name='device')(x)
# devices = [b.device() for b in sharded_x.device_buffers]
# num_device = len(devices)
# for k, d in self.items():
# if isinstance(d, math.random.RandomState):
# replicated[k] = jax.device_put_sharded([shard for shard in d.split(num_device)], devices)
# saved_states[k] = d.value
# else:
# replicated[k] = jax.device_put_replicated(d.value, devices)
# self.assign(replicated)
# yield
# visited = set()
# for k, d in self.items():
# # Careful not to reduce twice in case of
# # a variable and a reference to it.
# if id(d) not in visited:
# if isinstance(d, math.random.RandomState):
# d.value = saved_states[k]
# else:
# d.value = reduce_func(d)
# visited.add(id(d))

+ 341
- 104
brainpy/base/io.py

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-

from typing import Dict, Type, Union, Tuple, List
import logging
import os
import pickle

import numpy as np
@@ -9,35 +9,47 @@ import numpy as np
from brainpy import errors
from brainpy.base.collector import TensorCollector

Base = math = None
logger = logging.getLogger('brainpy.base.io')

try:
import h5py
except (ModuleNotFoundError, ImportError):
h5py = None

try:
import scipy.io as sio
except (ModuleNotFoundError, ImportError):
sio = None

__all__ = [
'SUPPORTED_FORMATS',
'save_h5',
'save_npz',
'save_pkl',
'save_mat',
'load_h5',
'load_npz',
'load_pkl',
'load_mat',
'save_as_h5',
'save_as_npz',
'save_as_pkl',
'save_as_mat',
'load_by_h5',
'load_by_npz',
'load_by_pkl',
'load_by_mat',
]

SUPPORTED_FORMATS = ['.h5', '.hdf5', '.npz', '.pkl', '.mat']


def _check(module, module_name, ext):
def check_dict_data(
a_dict: Dict,
key_type: Union[Type, Tuple[Type, ...]] = None,
val_type: Union[Type, Tuple[Type, ...]] = None,
name: str = None
):
"""Check the dict data."""
name = '' if (name is None) else f'"{name}"'
if not isinstance(a_dict, dict):
raise ValueError(f'{name} must be a dict, while we got {type(a_dict)}')
if key_type is not None:
for key, value in a_dict.items():
if not isinstance(key, key_type):
raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), '
f'while we got ({type(key)}, {type(value)})')
if val_type is not None:
for key, value in a_dict.items():
if not isinstance(value, val_type):
raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), '
f'while we got ({type(key)}, {type(value)})')


def _check_module(module, module_name, ext):
"""Check whether the required module is installed."""
if module is None:
raise errors.PackageMissingError(
'"{package}" must be installed when you want to save/load data with {ext} '
@@ -52,104 +64,329 @@ def _check_missing(variables, filename):
f'The missed variables are: {list(variables.keys())}.')


def save_h5(filename, all_vars):
_check(h5py, module_name='h5py', ext=os.path.splitext(filename))
assert isinstance(all_vars, dict)
all_vars = TensorCollector(all_vars).unique()
def _check_target(target):
from .base import Base
if not isinstance(target, Base):
raise TypeError(f'"target" must be instance of "{Base.__name__}", but we got {type(target)}')


not_found_msg = ('"{key}" is stored in {filename}. But we does '
'not find it is defined as variable in {target}.')
id_dismatch_msg = ('{key1} and {key2} refer to the same data in {filename}, '
                   'but they are different in {target}.')

DUPLICATE_KEY = 'duplicate_keys'
DUPLICATE_TARGET = 'duplicate_targets'


def _load(
target,
verbose: bool,
filename: str,
load_vars: dict,
duplicates: Tuple[List[str], List[str]],
remove_first_axis: bool = False
):
from brainpy import math as bm

# get variables
_check_target(target)
variables = target.vars(method='absolute', level=-1)
all_names = list(variables.keys())

# read data from file
for key in load_vars.keys():
if verbose:
print(f'Loading {key} ...')
if key not in variables:
raise KeyError(not_found_msg.format(key=key, target=target.name, filename=filename))
if remove_first_axis:
value = load_vars[key][0]
else:
value = load_vars[key]
variables[key].value = bm.asarray(value)
all_names.remove(key)

# check duplicate names
duplicate_keys = duplicates[0]
duplicate_targets = duplicates[1]
for key1, key2 in zip(duplicate_keys, duplicate_targets):
if key1 not in all_names:
raise KeyError(not_found_msg.format(key=key1, target=target.name, filename=filename))
if id(variables[key1]) != id(variables[key2]):
raise ValueError(id_dismatch_msg.format(key1=key1, key2=target, filename=filename, target=target.name))
all_names.remove(key1)

# check missing names
if len(all_names):
logger.warning(f'There are variable states missing in {filename}. '
               f'The missing variables are: {all_names}.')


def _unique_and_duplicate(collector: dict):
gather = TensorCollector()
id2name = dict()
duplicates = ([], [])
for k, v in collector.items():
id_ = id(v)
if id_ not in id2name:
gather[k] = v
id2name[id_] = k
else:
k2 = id2name[id_]
duplicates[0].append(k)
duplicates[1].append(k2)
duplicates = (duplicates[0], duplicates[1])
return gather, duplicates


def save_as_h5(filename: str, variables: dict):
"""Save variables into a HDF5 file.

Parameters
----------
filename: str
The filename to save.
variables: dict
All variables to save.
"""
if not (filename.endswith('.hdf5') or filename.endswith('.h5')):
raise ValueError(f'Cannot save variables as a HDF5 file. We only support file with '
f'postfix of ".hdf5" and ".h5". But we got {filename}')

from brainpy import math as bm
import h5py

# check variables
check_dict_data(variables, name='variables')
variables, duplicates = _unique_and_duplicate(variables)

# save
f = h5py.File(filename, "w")
for key, data in variables.items():
  f[key] = bm.as_numpy(data)
if len(duplicates[0]):
f.create_dataset(DUPLICATE_TARGET, data='+'.join(duplicates[1]))
f.create_dataset(DUPLICATE_KEY, data='+'.join(duplicates[0]))
f.close()


def load_h5(filename, target, verbose=False, check=False):
global math, Base
if Base is None: from brainpy.base.base import Base
if math is None: from brainpy import math
assert isinstance(target, Base)
_check(h5py, module_name='h5py', ext=os.path.splitext(filename))
def load_by_h5(filename: str, target, verbose: bool = False):
"""Load variables in a HDF5 file.

Parameters
----------
filename: str
The filename to load variables.
target: Base
The instance of :py:class:`~.brainpy.Base`.
verbose: bool
Whether report the load progress.
"""
if not (filename.endswith('.hdf5') or filename.endswith('.h5')):
raise ValueError(f'Cannot load variables from a HDF5 file. We only support file with '
f'postfix of ".hdf5" and ".h5". But we got {filename}')

# read data
import h5py
load_vars = dict()
with h5py.File(filename, "r") as f:
for key in f.keys():
if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue
load_vars[key] = np.asarray(f[key])
if DUPLICATE_KEY in f:
duplicate_keys = np.asarray(f[DUPLICATE_KEY]).item().decode("utf-8").split('+')
duplicate_targets = np.asarray(f[DUPLICATE_TARGET]).item().decode("utf-8").split('+')
duplicates = (duplicate_keys, duplicate_targets)
else:
duplicates = ([], [])

# assign values
_load(target, verbose, filename, load_vars, duplicates)
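
# Editor's sketch (not part of this PR): when two keys reference the same
# Variable, ``save_as_h5`` stores a single copy and records the pairing under
# the 'duplicate_keys'/'duplicate_targets' datasets, which ``load_by_h5``
# re-checks against the target. Requires the optional h5py dependency; the
# key names below are illustrative.
import brainpy.math as bm

_v = bm.Variable(bm.zeros(2))
save_as_h5('./_dup_demo.h5', variables={'net.a': _v, 'net.shared': _v})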


def save_as_npz(filename, variables, compressed=False):
"""Save variables into a numpy file.
Parameters
----------
filename: str
The filename to store.
variables: dict
Variables to save.
compressed: bool
Whether we use the compressed mode.
"""
if not filename.endswith('.npz'):
raise ValueError(f'Cannot save variables as a .npz file. We only support file with '
f'postfix of ".npz". But we got {filename}')

from brainpy import math as bm
check_dict_data(variables, name='variables')
variables, duplicates = _unique_and_duplicate(variables)

# save
variables = {k: bm.as_numpy(v) for k, v in variables.items()}
if len(duplicates[0]):
  variables[DUPLICATE_KEY] = np.asarray(duplicates[0])
  variables[DUPLICATE_TARGET] = np.asarray(duplicates[1])
if compressed:
  np.savez_compressed(filename, **variables)
else:
  np.savez(filename, **variables)


def load_by_npz(filename, target, verbose=False):
"""Load variables from a numpy file.

Parameters
----------
filename: str
The filename to load variables.
target: Base
The instance of :py:class:`~.brainpy.Base`.
verbose: bool
Whether report the load progress.
"""
if not filename.endswith('.npz'):
raise ValueError(f'Cannot load variables from a .npz file. We only support file with '
f'postfix of ".npz". But we got {filename}')

# load data
load_vars = dict()
all_data = np.load(filename)
for key in all_data.files:
if verbose: print(f'Loading {key} ...')
if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue
load_vars[key] = all_data[key]
if DUPLICATE_KEY in all_data:
duplicate_keys = all_data[DUPLICATE_KEY].tolist()
duplicate_targets = all_data[DUPLICATE_TARGET].tolist()
duplicates = (duplicate_keys, duplicate_targets)
else:
duplicates = ([], [])

# assign values
_load(target, verbose, filename, load_vars, duplicates)


def save_as_pkl(filename, variables):
"""Save variables into a pickle file.

Parameters
----------
filename: str
The filename to save.
variables: dict
All variables to save.
"""
if not (filename.endswith('.pkl') or filename.endswith('.pickle')):
raise ValueError(f'Cannot save variables into a pickle file. We only support file with '
f'postfix of ".pkl" and ".pickle". But we got {filename}')

check_dict_data(variables, name='variables')
variables, duplicates = _unique_and_duplicate(variables)
import brainpy.math as bm
targets = {k: bm.as_numpy(v) for k, v in variables.items()}
if len(duplicates[0]) > 0:
targets[DUPLICATE_KEY] = np.asarray(duplicates[0])
targets[DUPLICATE_TARGET] = np.asarray(duplicates[1])
with open(filename, 'wb') as f:
pickle.dump(targets, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_by_pkl(filename, target, verbose=False):
"""Load variables from a pickle file.

Parameters
----------
filename: str
The filename to load variables.
target: Base
The instance of :py:class:`~.brainpy.Base`.
verbose: bool
Whether report the load progress.
"""
if not (filename.endswith('.pkl') or filename.endswith('.pickle')):
raise ValueError(f'Cannot load variables from a pickle file. We only support file with '
f'postfix of ".pkl" and ".pickle". But we got {filename}')

# load variables
load_vars = dict()
with open(filename, 'rb') as f:
all_data = pickle.load(f)
for key, data in all_data.items():
if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue
load_vars[key] = data
if DUPLICATE_KEY in all_data:
duplicate_keys = all_data[DUPLICATE_KEY].tolist()
duplicate_targets = all_data[DUPLICATE_TARGET].tolist()
duplicates = (duplicate_keys, duplicate_targets)
else:
duplicates = ([], [])

# assign data
_load(target, verbose, filename, load_vars, duplicates)


def save_as_mat(filename, variables):
"""Save variables into a HDF5 file.

Parameters
----------
filename: str
The filename to save.
variables: dict
All variables to save.
"""
if not filename.endswith('.mat'):
raise ValueError(f'Cannot save variables into a .mat file. We only support file with '
f'postfix of ".mat". But we got {filename}')

from brainpy import math as bm
import scipy.io as sio

check_dict_data(variables, name='variables')
variables, duplicates = _unique_and_duplicate(variables)
variables = {k: np.expand_dims(bm.as_numpy(v), axis=0) for k, v in variables.items()}
if len(duplicates[0]):
variables[DUPLICATE_KEY] = np.expand_dims(np.asarray(duplicates[0]), axis=0)
variables[DUPLICATE_TARGET] = np.expand_dims(np.asarray(duplicates[1]), axis=0)
sio.savemat(filename, variables)
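Note that `scipy.io.savemat` stores everything as at least 2-D arrays, which is presumably why each value is wrapped with `np.expand_dims(..., axis=0)` above and unwrapped with `data[0]` when loading back; a quick sketch of the shape behaviour:

import numpy as np
import scipy.io as sio

sio.savemat('demo.mat', {'v': np.expand_dims(np.zeros(3), axis=0)})
print(sio.loadmat('demo.mat')['v'].shape)   # (1, 3)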


def load_by_mat(filename, target, verbose=False):
"""Load variables from a numpy file.

Parameters
----------
filename: str
The filename to load variables.
target: Base
The instance of :py:class:`~.brainpy.Base`.
verbose: bool
Whether to report the load progress.
"""
if not filename.endswith('.mat'):
raise ValueError(f'Cannot load variables from a .mat file. We only support file with '
f'postfix of ".mat". But we got {filename}')

def load_mat(filename, target, verbose=False, check=False):
global math, Base
if Base is None: from brainpy.base.base import Base
if math is None: from brainpy import math
assert isinstance(target, Base)
import scipy.io as sio

# load data
load_vars = dict()
all_data = sio.loadmat(filename)
all_vars = target.vars(method='absolute')
for key, data in all_data.items():
if verbose: print(f'Loading {key} ...')
var = all_vars.pop(key)
var[:] = math.asarray(data)
if check: _check_missing(all_vars, filename=filename)
if key.startswith('__'):
continue
if key in [DUPLICATE_KEY, DUPLICATE_TARGET]:
continue
load_vars[key] = data[0]
if DUPLICATE_KEY in all_data:
duplicate_keys = [a.strip() for a in all_data[DUPLICATE_KEY].tolist()[0]]
duplicate_targets = [a.strip() for a in all_data[DUPLICATE_TARGET].tolist()[0]]
duplicates = (duplicate_keys, duplicate_targets)
else:
duplicates = ([], [])

# assign values
_load(target, verbose, filename, load_vars, duplicates)

+ 3
- 2
brainpy/base/naming.py

@@ -44,8 +44,9 @@ def get_unique_name(type_):
return name


def clear_name_cache():
def clear_name_cache(ignore_warn=False):
"""Clear the cached names."""
_name2id.clear()
_typed_names.clear()
logger.warning(f'All named models and their ids are cleared.')
if not ignore_warn:
logger.warning(f'All named models and their ids are cleared.')
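A short usage sketch of the new `ignore_warn` flag (import path taken from the file header above):

from brainpy.base.naming import clear_name_cache

clear_name_cache()                  # logs a warning that all cached names were cleared
clear_name_cache(ignore_warn=True)  # clears silently, e.g. between unit tests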

+ 1
- 1
brainpy/base/tests/test_base.py

@@ -28,7 +28,7 @@ class TestCollectionFunction(unittest.TestCase):

net = bp.dyn.Network(a1=A(), a2=A())
print(net.nodes(level=2))
self.assertTrue(len(net.nodes(level=0)) == 0)
self.assertTrue(len(net.nodes(level=0)) == 1)
self.assertTrue(len(net.nodes(level=0, include_self=False)) == 0)
self.assertTrue(len(net.nodes(level=1)) == (1 + 2))
self.assertTrue(len(net.nodes(level=1, include_self=False)) == 2)


+ 0
- 14
brainpy/base/tests/test_circular_reference.py

@@ -74,17 +74,3 @@ def test_nodes():

assert len(abs_nodes) == 3
assert len(rel_nodes) == 5


def test_ints():
A = HH(1, name='X2')
B = HH(1, name='Y2')
A.pre = B
B.pre = A

net = bp.dyn.Network(A, B)
abs_ints = net.ints(method='absolute')
rel_ints = net.ints(method='relative')
print()
pprint(abs_ints.keys())
pprint(rel_ints.keys())

+ 49
- 70
brainpy/base/tests/test_collector.py

@@ -28,20 +28,15 @@ class GABAa_without_Variable(bp.dyn.TwoEndConn):
# variables
self.t_last_pre_spike = bp.math.ones(self.size) * -1e7
self.s = bp.math.zeros(self.size)
self.g = bp.dyn.ConstantDelay(size=self.size, delay=delay)

@bp.odeint
def int_s(self, s, t, TT):
return self.alpha * TT * (1 - s) - self.beta * s
self.int_s = bp.odeint(lambda s, t, TT: self.alpha * TT * (1 - s) - self.beta * s)

def update(self, t, dt):
def update(self, tdi):
spike = bp.math.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat
self.t_last_pre_spike[:] = bp.math.where(spike, t, self.t_last_pre_spike)
TT = ((t - self.t_last_pre_spike) < self.T_duration) * self.T
self.s[:] = self.int_s(self.s, t, TT)
self.g.push(self.g_max * self.s)
g = self.g.pull()
self.post.inputs -= bp.math.sum(g, axis=0) * (self.post.V - self.E)
self.t_last_pre_spike[:] = bp.math.where(spike, tdi.t, self.t_last_pre_spike)
TT = ((tdi.t - self.t_last_pre_spike) < self.T_duration) * self.T
self.s[:] = self.int_s(self.s, tdi.t, TT)
self.post.inputs -= bp.math.sum(self.s, axis=0) * (self.post.V - self.E)


class HH_without_Variable(bp.dyn.NeuGroup):
@@ -67,8 +62,9 @@ class HH_without_Variable(bp.dyn.NeuGroup):
self.inputs = bp.math.zeros(self.num)
self.spikes = bp.math.zeros(self.num, dtype=bp.math.bool_)

@bp.odeint
def integral(self, V, h, n, t, Iext):
self.integral = bp.odeint(self.derivative)

def derivative(self, V, h, n, t, Iext):
alpha = 0.07 * bp.math.exp(-(V + 58) / 20)
beta = 1 / (bp.math.exp(-0.1 * (V + 28)) + 1)
dhdt = self.phi * (alpha * (1 - h) - beta * h)
@@ -87,8 +83,8 @@ class HH_without_Variable(bp.dyn.NeuGroup):

return dVdt, dhdt, dndt

def update(self, t, dt):
V, h, n = self.integral(self.V, self.h, self.n, t, self.inputs)
def update(self, tdi):
V, h, n = self.integral(self.V, self.h, self.n, tdi.t, self.inputs)
self.spikes[:] = bp.math.logical_and(self.V < self.V_th, V >= self.V_th)
self.V[:] = V
self.h[:] = h
@@ -102,11 +98,11 @@ def test_subset_integrator():
syn.g_max = 0.1 / neu.num
net = bp.dyn.Network(neu, syn)

ints = net.ints()
ints = net.nodes(level=-1).subset(bp.integrators.Integrator)
print()
print(ints)

ode_ints = ints.subset(bp.integrators.ODEIntegrator)
ode_ints = ints.subset(bp.integrators.ODEIntegrator).unique()
print(ode_ints)
assert len(ode_ints) == 2

@@ -143,8 +139,9 @@ class HH_with_Variable(bp.dyn.NeuGroup):
self.inputs = bp.math.Variable(bp.math.zeros(self.num))
self.spikes = bp.math.Variable(bp.math.zeros(self.num, dtype=bp.math.bool_))

@bp.odeint
def integral(self, V, h, n, t, Iext):
self.integral = bp.odeint(self.derivative)

def derivative(self, V, h, n, t, Iext):
alpha = 0.07 * bp.math.exp(-(V + 58) / 20)
beta = 1 / (bp.math.exp(-0.1 * (V + 28)) + 1)
dhdt = self.phi * (alpha * (1 - h) - beta * h)
@@ -163,8 +160,8 @@ class HH_with_Variable(bp.dyn.NeuGroup):

return dVdt, dhdt, dndt

def update(self, t, dt):
V, h, n = self.integral(self.V, self.h, self.n, t, self.inputs)
def update(self, tdi):
V, h, n = self.integral(self.V, self.h, self.n, tdi.t, self.inputs)
self.spikes[:] = bp.math.logical_and(self.V < self.V_th, V >= self.V_th)
self.V[:] = V
self.h[:] = h
@@ -187,22 +184,11 @@ def test_neu_nodes_1():
neu = HH_with_Variable(10)
print()
print(neu.nodes().keys())
assert len(neu.nodes()) == 1
assert len(neu.nodes(level=-1, include_self=False)) == 1

print()
print(neu.nodes(method='relative').keys())
assert len(neu.nodes(method='relative')) == 1


def test_neu_ints_1():
neu = HH_with_Variable(10)
print()
print(neu.ints().keys())
assert len(neu.ints()) == 1

print()
print(neu.ints(method='relative').keys())
assert len(neu.ints(method='relative')) == 1
assert len(neu.nodes(method='relative', include_self=False)) == 1


class GABAa_with_Variable(bp.dyn.TwoEndConn):
@@ -227,20 +213,14 @@ class GABAa_with_Variable(bp.dyn.TwoEndConn):
# variables
self.t_last_pre_spike = bp.math.Variable(bp.math.ones(self.size) * -1e7)
self.s = bp.math.Variable(bp.math.zeros(self.size))
self.g = bp.dyn.ConstantDelay(size=self.size, delay=delay)

@bp.odeint
def int_s(self, s, t, TT):
return self.alpha * TT * (1 - s) - self.beta * s
self.int_s = bp.odeint(lambda s, t, TT: self.alpha * TT * (1 - s) - self.beta * s)

def update(self, t, _i):
def update(self, tdi):
spike = bp.math.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat
self.t_last_pre_spike[:] = bp.math.where(spike, t, self.t_last_pre_spike)
TT = ((t - self.t_last_pre_spike) < self.T_duration) * self.T
self.s[:] = self.int_s(self.s, t, TT)
self.g.push(self.g_max * self.s)
g = self.g.pull()
self.post.inputs -= bp.math.sum(g, axis=0) * (self.post.V - self.E)
self.t_last_pre_spike[:] = bp.math.where(spike, tdi.t, self.t_last_pre_spike)
TT = ((tdi.t - self.t_last_pre_spike) < self.T_duration) * self.T
self.s[:] = self.int_s(self.s, tdi.t, TT)
self.post.inputs -= bp.math.sum(self.g_max * self.s, axis=0) * (self.post.V - self.E)


def test_net_1():
@@ -251,29 +231,20 @@ def test_net_1():
# variables
print()
pprint(list(net.vars().keys()))
assert len(net.vars()) == 3
assert len(net.vars()) == 0

print()
pprint(list(net.vars(method='relative').keys()))
assert len(net.vars(method='relative')) == 3
assert len(net.vars(method='relative')) == 0

# nodes
print()
pprint(list(net.nodes().keys()))
assert len(net.nodes()) == 4

print()
pprint(list(net.nodes(method='relative').keys()))
assert len(net.nodes(method='relative')) == 5

# ints
print()
pprint(list(net.ints().keys()))
assert len(net.ints()) == 2
pprint(list(net.nodes().unique().keys()))
# assert len(net.nodes()) == 8

print()
pprint(list(net.ints(method='relative').keys()))
assert len(net.ints(method='relative')) == 3
pprint(list(net.nodes(method='relative').unique().keys()))
# assert len(net.nodes(method='relative')) == 12


def test_net_vars_2():
@@ -293,17 +264,25 @@ def test_net_vars_2():
# nodes
print()
pprint(list(net.nodes().keys()))
assert len(net.nodes()) == 4
# assert len(net.nodes()) == 8

print()
pprint(list(net.nodes(method='relative').keys()))
assert len(net.nodes(method='relative')) == 5
# assert len(net.nodes(method='relative')) == 6

# ints
print()
pprint(list(net.ints().keys()))
assert len(net.ints()) == 2

print()
pprint(list(net.ints(method='relative').keys()))
assert len(net.ints(method='relative')) == 3
def test_hidden_variables():
class BPClass(bp.base.Base):
_excluded_vars = ('_rng_', )

def __init__(self):
super(BPClass, self).__init__()

self._rng_ = bp.math.random.RandomState()
self.rng = bp.math.random.RandomState()

model = BPClass()

print(model.vars(level=-1).keys())
assert len(model.vars(level=-1)) == 1


+ 170
- 0
brainpy/base/tests/test_io.py

@@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-


import brainpy as bp
import brainpy.math as bm
import unittest


class TestIO1(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestIO1, self).__init__(*args, **kwargs)

rng = bm.random.RandomState()

class IO1(bp.dyn.DynamicalSystem):
def __init__(self):
super(IO1, self).__init__()

self.a = bm.Variable(bm.zeros(1))
self.b = bm.Variable(bm.ones(3))
self.c = bm.Variable(bm.ones((3, 4)))
self.d = bm.Variable(bm.ones((2, 3, 4)))

class IO2(bp.dyn.DynamicalSystem):
def __init__(self):
super(IO2, self).__init__()

self.a = bm.Variable(rng.rand(3))
self.b = bm.Variable(rng.randn(10))

io1 = IO1()
io2 = IO2()
io1.a2 = io2.a
io1.b2 = io2.b
io2.a2 = io1.a
io2.b2 = io2.b

self.net = bp.dyn.Container(io1, io2)

print(self.net.vars().keys())
print(self.net.vars().unique().keys())

def test_h5(self):
bp.base.save_as_h5('io_test_tmp.h5', self.net.vars())
bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True)

bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars())
bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)

def test_h5_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_h5('io_test_tmp.h52', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True)

def test_npz(self):
bp.base.save_as_npz('io_test_tmp.npz', self.net.vars())
bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True)

bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)

def test_npz_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)

def test_pkl(self):
bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars())
bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)

bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars())
bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)

def test_pkl_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)

def test_mat(self):
bp.base.save_as_mat('io_test_tmp.mat', self.net.vars())
bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True)

def test_mat_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)


class TestIO2(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestIO2, self).__init__(*args, **kwargs)

rng = bm.random.RandomState()

class IO1(bp.dyn.DynamicalSystem):
def __init__(self):
super(IO1, self).__init__()

self.a = bm.Variable(bm.zeros(1))
self.b = bm.Variable(bm.ones(3))
self.c = bm.Variable(bm.ones((3, 4)))
self.d = bm.Variable(bm.ones((2, 3, 4)))

class IO2(bp.dyn.DynamicalSystem):
def __init__(self):
super(IO2, self).__init__()

self.a = bm.Variable(rng.rand(3))
self.b = bm.Variable(rng.randn(10))

io1 = IO1()
io2 = IO2()

self.net = bp.dyn.Container(io1, io2)

print(self.net.vars().keys())
print(self.net.vars().unique().keys())

def test_h5(self):
bp.base.save_as_h5('io_test_tmp.h5', self.net.vars())
bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True)

bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars())
bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)

def test_h5_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_h5('io_test_tmp.h52', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True)

def test_npz(self):
bp.base.save_as_npz('io_test_tmp.npz', self.net.vars())
bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True)

bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)

def test_npz_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)

def test_pkl(self):
bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars())
bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)

bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars())
bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)

def test_pkl_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)

def test_mat(self):
bp.base.save_as_mat('io_test_tmp.mat', self.net.vars())
bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True)

def test_mat_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)

+ 0
- 27
brainpy/compat/__init__.py

@@ -1,27 +0,0 @@
# -*- coding: utf-8 -*-


__all__ = [
# modules
'brainobjects', 'layers', 'models',

# brainobjects
'DynamicalSystem', 'Container', 'Network',
'ConstantDelay', 'NeuGroup', 'TwoEndConn',

# integrators
'set_default_odeint', 'set_default_sdeint',
'get_default_odeint', 'get_default_sdeint',

# monitor
'Monitor',

# runners
'IntegratorRunner', 'DSRunner', 'StructRunner', 'ReportRunner'
]

from . import brainobjects, layers, models
from .brainobjects import *
from .integrators import *
from .monitor import *
from .runners import *

+ 0
- 92
brainpy/compat/brainobjects.py

@@ -1,92 +0,0 @@
# -*- coding: utf-8 -*-

import warnings

from brainpy import dyn

__all__ = [
'DynamicalSystem',
'Container',
'Network',
'ConstantDelay',
'NeuGroup',
'TwoEndConn',
]


class DynamicalSystem(dyn.DynamicalSystem):
"""Dynamical System.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.DynamicalSystem" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.DynamicalSystem" instead. '
'"brainpy.DynamicalSystem" is deprecated since '
'version 2.0.3', DeprecationWarning)
super(DynamicalSystem, self).__init__(*args, **kwargs)


class Container(dyn.Container):
"""Container.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.Container" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.Container" instead. '
'"brainpy.Container" is deprecated since '
'version 2.0.3', DeprecationWarning)
super(Container, self).__init__(*args, **kwargs)


class Network(dyn.Network):
"""Network.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.Network" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.Network" instead. '
'"brainpy.Network" is deprecated since '
'version 2.0.3', DeprecationWarning)
super(Network, self).__init__(*args, **kwargs)


class ConstantDelay(dyn.ConstantDelay):
"""Constant Delay.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.ConstantDelay" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.ConstantDelay" instead. '
'"brainpy.ConstantDelay" is deprecated since '
'version 2.0.3', DeprecationWarning)
super(ConstantDelay, self).__init__(*args, **kwargs)


class NeuGroup(dyn.NeuGroup):
"""Neuron group.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.NeuGroup" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.NeuGroup" instead. '
'"brainpy.NeuGroup" is deprecated since '
'version 2.0.3', DeprecationWarning)
super(NeuGroup, self).__init__(*args, **kwargs)


class TwoEndConn(dyn.TwoEndConn):
"""Two-end synaptic connection.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.TwoEndConn" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.TwoEndConn" instead. '
'"brainpy.TwoEndConn" is deprecated since '
'version 2.0.3', DeprecationWarning)
super(TwoEndConn, self).__init__(*args, **kwargs)

+ 0
- 60
brainpy/compat/integrators.py

@@ -1,60 +0,0 @@
# -*- coding: utf-8 -*-

import warnings

from brainpy.integrators import ode, sde

__all__ = [
'set_default_odeint',
'set_default_sdeint',
'get_default_odeint',
'get_default_sdeint',
]


def set_default_odeint(method):
"""Set default ode integrator.

.. deprecated:: 2.1.0
Please use "brainpy.ode.set_default_odeint" instead.
"""
warnings.warn('Please use "brainpy.ode.set_default_odeint" instead. '
'"brainpy.set_default_odeint" is deprecated since '
'version 2.1.0', DeprecationWarning)
ode.set_default_odeint(method)


def get_default_odeint():
"""Get default ode integrator.

.. deprecated:: 2.1.0
Please use "brainpy.ode.get_default_odeint" instead.
"""
warnings.warn('Please use "brainpy.ode.get_default_odeint" instead. '
'"brainpy.get_default_odeint" is deprecated since '
'version 2.1.0', DeprecationWarning)
ode.get_default_odeint()


def set_default_sdeint(method):
"""Set default sde integrator.

.. deprecated:: 2.1.0
Please use "brainpy.ode.set_default_sdeint" instead.
"""
warnings.warn('Please use "brainpy.sde.set_default_sdeint" instead. '
'"brainpy.set_default_sdeint" is deprecated since '
'version 2.1.0', DeprecationWarning)
sde.set_default_sdeint(method)


def get_default_sdeint():
"""Get default sde integrator.

.. deprecated:: 2.1.0
Please use "brainpy.ode.get_default_sdeint" instead.
"""
warnings.warn('Please use "brainpy.sde.get_default_sdeint" instead. '
'"brainpy.get_default_sdeint" is deprecated since '
'version 2.1.0', DeprecationWarning)
sde.get_default_sdeint()

+ 0
- 61
brainpy/compat/layers.py

@@ -1,61 +0,0 @@
# -*- coding: utf-8 -*-

import warnings

import jax.numpy as jnp
import numpy as onp

import brainpy.math as bm
from brainpy.base.base import Base

__all__ = [
'Module',
]


def _check_args(args):
if args is None:
return tuple()
elif isinstance(args, tuple):
return args
else:
return (args,)


class Module(Base):
"""Basic module class.

.. deprecated:: 2.1.0
"""

@staticmethod
def get_param(param, size):
return bm.TrainVar(Module.init_param(param, size))

@staticmethod
def init_param(param, size):
if param is None:
return None
if callable(param):
param = param(size)
elif isinstance(param, onp.ndarray):
param = bm.asarray(param)
elif isinstance(param, (bm.JaxArray, jnp.ndarray)):
pass
else:
raise ValueError(f'Unknown param type {type(param)}: {param}')
assert param.shape == size, f'"param.shape" is not the required size {size}'
return param

def __init__(self, name=None): # initialize parameters
warnings.warn('Please use "brainpy.rnns.Module" instead. '
'"brainpy.layers.Module" is deprecated since '
'version 2.1.0.', DeprecationWarning)
super(Module, self).__init__(name=name)

def __call__(self, *args, **kwargs): # initialize variables
return self.call(*args, **kwargs)

def call(self, *args, **kwargs):
raise NotImplementedError


+ 0
- 98
brainpy/compat/models.py

@@ -1,98 +0,0 @@
# -*- coding: utf-8 -*-

import warnings

from brainpy.dyn import neurons, synapses

__all__ = [
'LIF',
'AdExIF',
'Izhikevich',
'ExpCOBA',
'ExpCUBA',
'DeltaSynapse',
]


class LIF(neurons.LIF):
"""LIF neuron model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.LIF" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.LIF" instead. '
'"brainpy.models.LIF" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(LIF, self).__init__(*args, **kwargs)


class AdExIF(neurons.AdExIF):
"""AdExIF neuron model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.AdExIF" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.AdExIF" instead. '
'"brainpy.models.AdExIF" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(AdExIF, self).__init__(*args, **kwargs)


class Izhikevich(neurons.Izhikevich):
"""Izhikevich neuron model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.Izhikevich" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.Izhikevich" instead. '
'"brainpy.models.Izhikevich" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(Izhikevich, self).__init__(*args, **kwargs)


class ExpCOBA(synapses.ExpCOBA):
"""ExpCOBA synapse model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.ExpCOBA" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.ExpCOBA" instead. '
'"brainpy.models.ExpCOBA" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(ExpCOBA, self).__init__(*args, **kwargs)


class ExpCUBA(synapses.ExpCUBA):
"""ExpCUBA synapse model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.ExpCUBA" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.ExpCUBA" instead. '
'"brainpy.models.ExpCUBA" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(ExpCUBA, self).__init__(*args, **kwargs)


class DeltaSynapse(synapses.DeltaSynapse):
"""Delta synapse model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.DeltaSynapse" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.DeltaSynapse" instead. '
'"brainpy.models.DeltaSynapse" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(DeltaSynapse, self).__init__(*args, **kwargs)

+ 0
- 21
brainpy/compat/monitor.py

@@ -1,21 +0,0 @@
# -*- coding: utf-8 -*-
import warnings

from brainpy.running import monitor

__all__ = [
'Monitor'
]


class Monitor(monitor.Monitor):
"""Monitor class.

.. deprecated:: 2.1.0
Please use "brainpy.running.Monitor" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.running.Monitor" instead. '
'"brainpy.Monitor" is deprecated since version 2.1.0.',
DeprecationWarning)
super(Monitor, self).__init__(*args, **kwargs)

+ 0
- 65
brainpy/compat/runners.py

@@ -1,65 +0,0 @@
# -*- coding: utf-8 -*-

import warnings

from brainpy.dyn import runners as dyn_runner
from brainpy.integrators import runner as intg_runner

__all__ = [
'IntegratorRunner',
'DSRunner',
'StructRunner',
'ReportRunner'
]


class IntegratorRunner(intg_runner.IntegratorRunner):
"""Integrator runner class.

.. deprecated:: 2.1.0
Please use "brainpy.integrators.IntegratorRunner" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.integrators.IntegratorRunner" instead. '
'"brainpy.IntegratorRunner" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(IntegratorRunner, self).__init__(*args, **kwargs)


class DSRunner(dyn_runner.DSRunner):
"""Dynamical system runner class.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.DSRunner" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.DSRunner" instead. '
'"brainpy.DSRunner" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(DSRunner, self).__init__(*args, **kwargs)


class StructRunner(dyn_runner.StructRunner):
"""Dynamical system runner class.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.StructRunner" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.StructRunner" instead. '
'"brainpy.StructRunner" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(StructRunner, self).__init__(*args, **kwargs)


class ReportRunner(dyn_runner.ReportRunner):
"""Dynamical system runner class.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.ReportRunner" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.ReportRunner" instead. '
'"brainpy.ReportRunner" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(ReportRunner, self).__init__(*args, **kwargs)

+ 354
- 127
brainpy/connect/random_conn.py

@@ -3,7 +3,7 @@
import numpy as np

from brainpy.errors import ConnectorError
from brainpy.tools.others import numba_seed, numba_jit, SUPPORT_NUMBA, format_seed
from .base import *

__all__ = [
@@ -11,6 +11,7 @@ __all__ = [
'FixedPreNum',
'FixedPostNum',
'GaussianProb',
'ProbDist',

'SmallWorld',
'ScaleFreeBA',
@@ -19,85 +20,76 @@ __all__ = [
]


# @tools.numba_jit
def _random_prob_conn(rng, pre_i, num_post, prob, include_self):
p = rng.random(num_post) <= prob
if (not include_self) and pre_i < num_post:
p[pre_i] = False
conn_j = np.asarray(np.where(p)[0], dtype=IDX_DTYPE)
return conn_j


class FixedProb(TwoEndConnector):
"""Connect the post-synaptic neurons with fixed probability.

Parameters
----------
prob : float
The conn probability.
The conn probability.
pre_ratio: float
The ratio of pre-synaptic neurons to connect.
include_self : bool
Whether create (i, i) conn?
Whether create (i, i) conn?
seed : optional, int
Seed the random generator.
Seed the random generator.
"""

def __init__(self, prob, include_self=True, seed=None):
def __init__(self, prob, pre_ratio=1., include_self=True, seed=None):
super(FixedProb, self).__init__()
assert 0. <= prob <= 1.
self.prob = prob
self.pre_ratio = pre_ratio
self.include_self = include_self
self.seed = seed
self.rng = np.random.RandomState(seed=seed)
self.seed = format_seed(seed)
self.rng = np.random.RandomState(seed=self.seed)

rng = np.random if SUPPORT_NUMBA else self.rng

def _connect(pre_i, num_post):
if rng.random() < pre_ratio:
p = rng.random(num_post) <= prob
if (not include_self) and pre_i < num_post:
p[pre_i] = False
return np.where(p)[0]

self._connect = numba_jit(_connect)

def build_conn(self):
# seed
self.seed = self.rng.randint(1, int(1e7))
if SUPPORT_NUMBA: numba_seed(self.seed)

# make connections
ind = []
count = np.zeros(self.pre_num, dtype=IDX_DTYPE)

for i in range(self.pre_num):
posts = _random_prob_conn(self.rng, pre_i=i, num_post=self.post_num,
prob=self.prob, include_self=self.include_self)
ind.append(posts)
count[i] = len(posts)

ind = np.concatenate(ind)
posts = self._connect(pre_i=i, num_post=self.post_num)
if posts is not None:
ind.append(posts)
count[i] = len(posts)
ind = np.concatenate(ind) if len(ind) > 0 else np.asarray([], dtype=IDX_DTYPE)
indptr = np.concatenate(([0], count)).cumsum()

return 'csr', (ind, indptr)
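A usage sketch of `FixedProb` with the new `pre_ratio` argument; the sizes and probabilities below are arbitrary:

import brainpy as bp

conn = bp.connect.FixedProb(prob=0.2, pre_ratio=0.5, include_self=False, seed=123)
conn(pre_size=100, post_size=100)
mat = conn.require(bp.connect.CONN_MAT)   # dense boolean matrix built from the CSR data above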


# @tools.numba_jit
def _fixed_num_prob(rng, num_need, num_total, i=0, include_self=False):
prob = rng.random(num_total)
if not include_self and i <= num_total:
prob[i] = 1.
neu_idx = np.argsort(prob)[:num_need]
return np.asarray(neu_idx, dtype=IDX_DTYPE)


class FixedPreNum(TwoEndConnector):
"""Connect the pre-synaptic neurons with fixed number for each post-synaptic neuron.
class FixedNum(TwoEndConnector):
"""Connect with fixed number for each pre- or post-synaptic neuron.

Parameters
----------
num : float, int
The connection probability (if "num" is float) or the fixed number of
connectivity (if "num" is int).
The conn probability (if "num" is float) or the fixed number of
connectivity (if "num" is int).
include_self : bool
Whether create (i, i) conn ?
Whether create (i, i) conn ?
seed : None, int
Seed the random generator.

- ``matrix``: This method will create a big matrix, then, the connectivity is constructed
from this matrix :math:`(N_{pre}, N_{post})`. In a large network, this method will
consume huge memories, including a matrix: :math:`(N_{pre}, N_{post})`, two vectors:
:math:`2 * N_{need} * N_{post}`.
- ``iter``: This method will iteratively build the synaptic connections. It has the
minimum pressure of memory consuming, only :math:`2 * N_{need} * N_{post}`
(``i`` and ``j`` vectors).
Seed the random generator.
"""

def __init__(self, num, include_self=True, seed=None):
super(FixedPreNum, self).__init__()
super(FixedNum, self).__init__()
if isinstance(num, int):
assert num >= 0, '"num" must be a non-negative integer.'
elif isinstance(num, float):
@@ -105,9 +97,32 @@ class FixedPreNum(TwoEndConnector):
else:
raise ConnectorError(f'Unknown type: {type(num)}')
self.num = num
self.seed = seed
self.seed = format_seed(seed)
self.include_self = include_self
self.rng = np.random.RandomState(seed=seed)
self.rng = np.random.RandomState(seed=self.seed)
rng = np.random if SUPPORT_NUMBA else self.rng

def _fixed_num_prob(num_need, num_total, i=0):
prob = rng.random(num_total)
if not include_self and i <= num_total:
prob[i] = 1.
neu_idx = np.argsort(prob)[:num_need]
return np.asarray(neu_idx, dtype=IDX_DTYPE)

self._connect = numba_jit(_fixed_num_prob)


class FixedPreNum(FixedNum):
"""Connect the pre-synaptic neurons with fixed number for each post-synaptic neuron.

Parameters
----------
num : float, int
The connection probability (if "num" is float) or the fixed number of
connectivity (if "num" is int).
include_self : bool
Whether create (i, i) conn ?
"""

def build_conn(self):
# check
@@ -119,18 +134,22 @@ class FixedPreNum(TwoEndConnector):
assert 0. <= self.num <= 1., f'"num" must be in [0., 1.), but got {self.num}'
num = int(self.pre_num * self.num)

# seed
self.seed = self.rng.randint(1, int(1e7))
numba_seed(self.seed)

# make connections
pre_ids = []
for i in range(self.post_num):
pres = _fixed_num_prob(rng=self.rng, num_need=num, num_total=self.pre_num,
i=i, include_self=self.include_self)
pres = self._connect(num_need=num, num_total=self.pre_num, i=i)
pre_ids.append(pres)
pre_ids = np.concatenate(pre_ids)
pre_ids = np.concatenate(pre_ids) if len(pre_ids) > 0 else np.asarray([], dtype=IDX_DTYPE)
post_ids = np.repeat(np.arange(self.post_num), num)

return 'ij', (pre_ids, post_ids)
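A similar sketch for the fixed-number connectors (`FixedPostNum` below is the mirror image); the numbers are arbitrary:

import brainpy as bp

conn = bp.connect.FixedPreNum(num=10, include_self=False, seed=42)
conn(pre_size=100, post_size=80)
mat = conn.require(bp.connect.CONN_MAT)   # every post-synaptic neuron gets exactly 10 pre-synaptic partners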


class FixedPostNum(TwoEndConnector):
class FixedPostNum(FixedNum):
"""Connect the post-synaptic neurons with fixed number for each pre-synaptic neuron.

Parameters
@@ -142,49 +161,27 @@ class FixedPostNum(TwoEndConnector):
Whether create (i, i) conn ?
seed : None, int
Seed the random generator.
method : str
The method used to create the connection.

- ``matrix``: This method will create a big matrix, then, the connectivity is constructed
from this matrix :math:`(N_{pre}, N_{post})`. In a large network, this method will
consume huge memories, including a matrix: :math:`(N_{pre}, N_{post})`, two vectors:
:math:`2 * N_{need} * N_{pre}`.
- ``iter``: This method will iteratively build the synaptic connections. It has the
minimum pressure of memory consuming, only :math:`2 * N_{need} * N_{pre}`
(``i`` and ``j`` vectors).
"""

def __init__(self, num, include_self=True, seed=None):
super(FixedPostNum, self).__init__()
if isinstance(num, int):
assert num >= 0, '"num" must be a non-negative integer.'
elif isinstance(num, float):
assert 0. <= num <= 1., '"num" must be in [0., 1.).'
else:
raise ConnectorError(f'Unknown type: {type(num)}')
self.num = num
self.seed = seed
self.include_self = include_self
self.rng = np.random.RandomState(seed=seed)

def build_conn(self):
# check
if isinstance(self.num, int):
assert 0 <= self.num <= self.post_num, f'"num" must be smaller than "self.post_num", ' \
f'but got {self.num} > {self.post_num}'
prob = self.num / self.post_num
num = self.num
else:
assert 0. <= self.num <= 1., f'"num" must be in [0., 1.), but got {self.num}'
num = int(self.post_num * self.num)
prob = self.num

# seed
self.seed = self.rng.randint(1, int(1e7))
numba_seed(self.seed)

# make connections
post_ids = [] # i.e. post_ids
for i in range(self.pre_num):
posts = _fixed_num_prob(rng=self.rng, num_need=num, num_total=self.post_num,
i=i, include_self=self.include_self)
posts = self._connect(num_need=num, num_total=self.post_num, i=i)
post_ids.append(posts)

post_ids = np.concatenate(post_ids)
count = np.ones(self.pre_num, dtype=IDX_DTYPE) * num
indptr = np.concatenate(([0], count)).cumsum()
@@ -230,16 +227,23 @@ class GaussianProb(OneEndConnector):
The random seed.
"""

def __init__(self, sigma, encoding_values=None, normalize=True, include_self=True,
periodic_boundary=False, seed=None):
def __init__(
self,
sigma: float,
encoding_values=None,
normalize: bool = True,
include_self: bool = True,
periodic_boundary: bool = False,
seed: int = None
):
super(GaussianProb, self).__init__()
self.sigma = sigma
self.encoding_values = encoding_values
self.normalize = normalize
self.include_self = include_self
self.periodic_boundary = periodic_boundary
self.seed = seed
self.rng = np.random.RandomState(seed)
self.seed = format_seed(seed)
self.rng = np.random.RandomState(self.seed)

def build_conn(self):
# value range to encode
@@ -304,22 +308,6 @@ class GaussianProb(OneEndConnector):
return 'mat', conn_mat


# @tools.numba_jit
def _smallworld_rewire(prob, i, all_j, include_self):
if np.random.random(1) < prob:
non_connected = np.where(all_j == False)[0]
if len(non_connected) <= 1:
return -1
# Enforce no self-loops or multiple edges
w = np.random.choice(non_connected)
while (not include_self) and w == i:
# non_connected.remove(w)
w = np.random.choice(non_connected)
return w
else:
return -1


class SmallWorld(TwoEndConnector):
"""Build a Watts–Strogatz small-world graph.

@@ -351,16 +339,47 @@ class SmallWorld(TwoEndConnector):
Nature, 393, pp. 440--442, 1998.
"""

def __init__(self, num_neighbor, prob, directed=False, include_self=False):
def __init__(
self,
num_neighbor,
prob,
directed=False,
include_self=False,
seed=None
):
super(SmallWorld, self).__init__()
self.prob = prob
self.directed = directed
self.num_neighbor = num_neighbor
self.include_self = include_self

self.seed = format_seed(seed)
self.rng = np.random.RandomState(seed=self.seed)
rng = np.random if SUPPORT_NUMBA else self.rng

def _smallworld_rewire(i, all_j):
if rng.random(1) < prob:
non_connected = np.where(np.logical_not(all_j))[0]
if len(non_connected) <= 1:
return -1
# Enforce no self-loops or multiple edges
w = rng.choice(non_connected)
while (not include_self) and w == i:
# non_connected.remove(w)
w = rng.choice(non_connected)
return w
else:
return -1

self._connect = numba_jit(_smallworld_rewire)

def build_conn(self):
assert self.pre_size == self.post_size

# seed
self.seed = self.rng.randint(1, int(1e7))
numba_seed(self.seed)

if isinstance(self.pre_size, int) or (isinstance(self.pre_size, (tuple, list)) and len(self.pre_size) == 1):
num_node = self.pre_num

@@ -386,18 +405,18 @@ class SmallWorld(TwoEndConnector):
if self.directed:
# inner loop in node order
for u, v in zip(nodes, targets):
w = _smallworld_rewire(prob=self.prob, i=u, all_j=conn[u], include_self=self.include_self)
w = self._connect(i=u, all_j=conn[u])
if w != -1:
conn[u, v] = False
conn[u, w] = True
w = _smallworld_rewire(prob=self.prob, i=u, all_j=conn[:, u], include_self=self.include_self)
w = self._connect(i=u, all_j=conn[:, u])
if w != -1:
conn[v, u] = False
conn[w, u] = True
else:
# inner loop in node order
for u, v in zip(nodes, targets):
w = _smallworld_rewire(prob=self.prob, i=u, all_j=conn[u], include_self=self.include_self)
w = self._connect(i=u, all_j=conn[u])
if w != -1:
conn[u, v] = False
conn[v, u] = False
@@ -410,19 +429,19 @@ class SmallWorld(TwoEndConnector):
return 'mat', conn


def _random_subset(seq, m, rng):
"""Return m unique elements from seq.
This differs from random.sample which can return repeated
elements if seq holds repeated elements.
Note: rng is a random.Random or numpy.random.RandomState instance.
"""
targets = set()
while len(targets) < m:
x = rng.choice(seq)
targets.add(x)
return targets
# def _random_subset(seq, m, rng):
# """Return m unique elements from seq.
#
# This differs from random.sample which can return repeated
# elements if seq holds repeated elements.
#
# Note: rng is a random.Random or numpy.random.RandomState instance.
# """
# targets = set()
# while len(targets) < m:
# x = rng.choice(seq)
# targets.add(x)
# return targets


class ScaleFreeBA(TwoEndConnector):
@@ -455,11 +474,26 @@ class ScaleFreeBA(TwoEndConnector):
super(ScaleFreeBA, self).__init__()
self.m = m
self.directed = directed
self.seed = seed
self.rng = np.random.RandomState(seed)
self.seed = format_seed(seed)
self.rng = np.random.RandomState(self.seed)
rng = np.random if SUPPORT_NUMBA else self.rng

def _random_subset(seq, m):
targets = set()
while len(targets) < m:
x = rng.choice(seq)
targets.add(x)
return targets

self._connect = numba_jit(_random_subset)

def build_conn(self):
assert self.pre_num == self.post_num

# seed
self.seed = self.rng.randint(1, int(1e7))
numba_seed(self.seed)

num_node = self.pre_num
if self.m < 1 or self.m >= num_node:
raise ConnectorError(f"Barabási–Albert network must have m >= 1 and "
@@ -485,7 +519,7 @@ class ScaleFreeBA(TwoEndConnector):
repeated_nodes.extend([source] * self.m)
# Now choose m unique nodes from the existing nodes
# Pick uniformly from repeated_nodes (preferential attachment)
targets = list(_random_subset(repeated_nodes, self.m, self.rng))
targets = list(self._connect(np.asarray(repeated_nodes), self.m))
source += 1

return 'mat', conn
@@ -526,11 +560,25 @@ class ScaleFreeBADual(TwoEndConnector):
self.m2 = m2
self.p = p
self.directed = directed
self.seed = seed
self.rng = np.random.RandomState(seed=seed)
self.seed = format_seed(seed)
self.rng = np.random.RandomState(self.seed)
rng = np.random if SUPPORT_NUMBA else self.rng

def _random_subset(seq, m):
targets = set()
while len(targets) < m:
x = rng.choice(seq)
targets.add(x)
return targets

self._connect = numba_jit(_random_subset)

def build_conn(self):
assert self.pre_num == self.post_num
# seed
self.seed = self.rng.randint(1, int(1e7))
numba_seed(self.seed)

num_node = self.pre_num
if self.m1 < 1 or self.m1 >= num_node:
raise ConnectorError(f"Dual Barabási–Albert network must have m1 >= 1 and m1 < num_node, "
@@ -565,7 +613,7 @@ class ScaleFreeBADual(TwoEndConnector):
m = self.m1 if self.rng.random() < self.p else self.m2
# Now choose m unique nodes from the existing nodes
# Pick uniformly from repeated_nodes (preferential attachment)
targets = list(_random_subset(repeated_nodes, m, self.rng))
targets = list(self._connect(np.asarray(repeated_nodes), m))
source += 1

return 'mat', conn
@@ -622,11 +670,24 @@ class PowerLaw(TwoEndConnector):
if self.p > 1 or self.p < 0:
raise ConnectorError(f"p must be in [0,1], while p={self.p}")
self.directed = directed
self.seed = seed
self.rng = np.random.RandomState(seed)
self.seed = format_seed(seed)
self.rng = np.random.RandomState(self.seed)
rng = np.random if SUPPORT_NUMBA else self.rng

def _random_subset(seq, m):
targets = set()
while len(targets) < m:
x = rng.choice(seq)
targets.add(x)
return targets

self._connect = numba_jit(_random_subset)

def build_conn(self):
assert self.pre_num == self.post_num
# seed
self.seed = self.rng.randint(1, int(1e7))
numba_seed(self.seed)
num_node = self.pre_num
if self.m < 1 or num_node < self.m:
raise ConnectorError(f"Must have m>1 and m<n, while m={self.m} and n={num_node}")
@@ -636,7 +697,7 @@ class PowerLaw(TwoEndConnector):
# with nodes repeated once for each adjacent edge
source = self.m # next node is m
while source < num_node: # Now add the other n-1 nodes
possible_targets = _random_subset(repeated_nodes, self.m, self.rng)
possible_targets = self._connect(np.asarray(repeated_nodes), self.m)
# do one preferential attachment for new node
target = possible_targets.pop()
conn[source, target] = True
@@ -668,3 +729,169 @@ class PowerLaw(TwoEndConnector):

return 'mat', conn


@numba_jit
def pos2ind(pos, size):
idx = 0
for i, p in enumerate(pos):
idx += p * np.prod(size[i + 1:])
return idx
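`pos2ind` flattens an n-dimensional grid position into a row-major index; for example, position (1, 2) in a (3, 4) grid maps to 1 * 4 + 2 = 6:

import numpy as np

idx = pos2ind(np.asarray((1, 2)), np.asarray((3, 4)))   # -> 6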


class ProbDist(TwoEndConnector):
"""Connection with a maximum distance under a probability `p`.

.. versionadded:: 2.1.13

Parameters
----------
dist: float, int
The maximum distance between two points.
prob: float
The connection probability, within 0. and 1.
pre_ratio: float
The ratio of pre-synaptic neurons to connect.
seed: optional, int
The random seed.
include_self: bool
Whether include the point at the same position.

"""

def __init__(self, dist=1, prob=1., pre_ratio=1., seed=None, include_self=True):
super(ProbDist, self).__init__()

self.prob = prob
self.pre_ratio = pre_ratio
self.dist = dist
self.seed = format_seed(seed)
self.rng = np.random.RandomState(self.seed)
self.include_self = include_self

rng = np.random if SUPPORT_NUMBA else self.rng

def _connect_1d(pre_pos, pre_size, post_size, n_dim):
all_post_ids = []
all_pre_ids = []
if rng.random() < pre_ratio:
normalized_pos = []
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos.append(pre_pos[i] * post_len / pre_len)
for i in range(post_size[0]):
post_pos = np.asarray((i,))
d = np.sum(np.abs(pre_pos - post_pos))
if d <= dist:
if d == 0. and not include_self:
continue
if rng.random() <= prob:
all_post_ids.append(pos2ind(post_pos, post_size))
all_pre_ids.append(pos2ind(pre_pos, pre_size))
return all_pre_ids, all_post_ids

def _connect_2d(pre_pos, pre_size, post_size, n_dim):
all_post_ids = []
all_pre_ids = []
if rng.random() < pre_ratio:
normalized_pos = []
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos.append(pre_pos[i] * post_len / pre_len)
for i in range(post_size[0]):
for j in range(post_size[1]):
post_pos = np.asarray((i, j))
d = np.sqrt(np.sum(np.square(pre_pos - post_pos)))
if d <= dist:
if d == 0. and not include_self:
continue
if np.random.random() <= prob:
all_post_ids.append(pos2ind(post_pos, post_size))
all_pre_ids.append(pos2ind(pre_pos, pre_size))
return all_pre_ids, all_post_ids

def _connect_3d(pre_pos, pre_size, post_size, n_dim):
all_post_ids = []
all_pre_ids = []
if rng.random() < pre_ratio:
normalized_pos = []
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos.append(pre_pos[i] * post_len / pre_len)
for i in range(post_size[0]):
for j in range(post_size[1]):
for k in range(post_size[2]):
post_pos = np.asarray((i, j, k))
d = np.sqrt(np.sum(np.square(pre_pos - post_pos)))
if d <= dist:
if d == 0. and not include_self:
continue
if np.random.random() <= prob:
all_post_ids.append(pos2ind(post_pos, post_size))
all_pre_ids.append(pos2ind(pre_pos, pre_size))
return all_pre_ids, all_post_ids

def _connect_4d(pre_pos, pre_size, post_size, n_dim):
all_post_ids = []
all_pre_ids = []
if rng.random() < pre_ratio:
normalized_pos = []
for i in range(n_dim):
pre_len = pre_size[i]
post_len = post_size[i]
normalized_pos.append(pre_pos[i] * post_len / pre_len)
for i in range(post_size[0]):
for j in range(post_size[1]):
for k in range(post_size[2]):
for l in range(post_size[3]):
post_pos = np.asarray((i, j, k, l))
d = np.sqrt(np.sum(np.square(pre_pos - post_pos)))
if d <= dist:
if d == 0. and not include_self:
continue
if np.random.random() <= prob:
all_post_ids.append(pos2ind(post_pos, post_size))
all_pre_ids.append(pos2ind(pre_pos, pre_size))
return all_pre_ids, all_post_ids

self._connect_1d = numba_jit(_connect_1d)
self._connect_2d = numba_jit(_connect_2d)
self._connect_3d = numba_jit(_connect_3d)
self._connect_4d = numba_jit(_connect_4d)

def build_conn(self):
if len(self.pre_size) != len(self.post_size):
raise ValueError('The dimensions of shapes of two objects to establish connections should '
f'be the same. But we got dimension {len(self.pre_size)} != {len(self.post_size)}. '
f'Specifically, pre size = {self.pre_size}, post size = {self.post_size}')
self.seed = self.rng.randint(1, int(1e7))
numba_seed(self.seed)

# connections
n_dim = len(self.pre_size)
if n_dim == 1:
f = self._connect_1d
elif n_dim == 2:
f = self._connect_2d
elif n_dim == 3:
f = self._connect_3d
elif n_dim == 4:
f = self._connect_4d
else:
raise NotImplementedError('Does not support the network dimension bigger than 4.')

pre_size = np.asarray(self.pre_size)
post_size = np.asarray(self.post_size)
connected_pres = []
connected_posts = []
pre_ids = np.meshgrid(*(np.arange(p) for p in self.pre_size))
pre_ids = tuple([(np.moveaxis(p, 0, 1).flatten()) if p.ndim > 1 else p.flatten() for p in pre_ids])
size = np.prod(pre_size)
for i in range(size):
pre_pos = np.asarray([p[i] for p in pre_ids])
pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim)
connected_pres.extend(pres)
connected_posts.extend(posts)
return 'ij', (np.asarray(connected_pres), np.asarray(connected_posts))
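A usage sketch of `ProbDist` on a small 2-D grid; the distance, probability, and sizes below are illustrative only:

import brainpy as bp

conn = bp.connect.ProbDist(dist=2, prob=0.5, pre_ratio=1., include_self=False, seed=0)
conn(pre_size=(10, 10), post_size=(10, 10))
mat = conn.require(bp.connect.CONN_MAT)   # pairs within Euclidean distance 2, each connected with probability 0.5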

+ 4
- 2
brainpy/connect/regular_conn.py

@@ -5,6 +5,7 @@ import logging
import numpy as np

from brainpy.errors import ConnectorError
from brainpy.tools.others import numba_jit

from .base import *

@@ -66,7 +67,7 @@ class All2All(TwoEndConnector):
all2all = All2All(include_self=True)


# @tools.numba_jit
@numba_jit
def _grid_four(height, width, row, include_self):
conn_i = []
conn_j = []
@@ -122,10 +123,11 @@ class GridFour(OneEndConnector):

return 'ij', (pre_ids, post_ids)


grid_four = GridFour()


# @tools.numba_jit
@numba_jit
def _grid_n(height, width, row, n, include_self):
conn_i = []
conn_j = []


+ 5
- 4
brainpy/connect/tests/test_random_conn.py

@@ -110,12 +110,12 @@ def test_gaussian_prob4():


def test_SmallWorld1():
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False)
conn(pre_size=10, post_size=10)
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False)
conn(pre_size=10, post_size=10)

mat = conn.require(bp.connect.CONN_MAT)
mat = conn.require(bp.connect.CONN_MAT)

print('conn_mat', mat)
print('conn_mat', mat)


def test_SmallWorld3():
@@ -126,6 +126,7 @@ def test_SmallWorld3():

print('conn_mat', mat)


def test_SmallWorld2():
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5)
conn(pre_size=(100,), post_size=(100,))


+ 2
- 1
brainpy/datasets/__init__.py

@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-

from .chaotic_systems import *
from .chaos import *
from .vision import *


+ 239
- 0
brainpy/datasets/_internally_replaced_utils.py

@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-

import ctypes
import errno
import hashlib
import importlib.machinery
import os
import re
import shutil
import sys
import tempfile
import warnings
import zipfile
from urllib.parse import urlparse
from urllib.request import urlopen, Request
from brainpy import math as bm

from tqdm import tqdm

ENV_TORCH_HOME = 'BRAINPY_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'


def _get_torch_home():
torch_home = os.path.expanduser(
os.getenv(ENV_TORCH_HOME, os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'brainpy')))
return torch_home
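Under the defaults above, downloaded files therefore land under `~/.cache/brainpy/`; setting the `BRAINPY_HOME` environment variable redirects the cache (it must be set before this module is imported, since the paths are resolved at import time), e.g.:

import os
os.environ['BRAINPY_HOME'] = '/data/brainpy_cache'   # hypothetical location; set before importing brainpy.datasets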


# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
_USE_SHARDED_DATASETS = False


def _download_file_from_remote_location(fpath: str, url: str) -> None:
pass


def _is_remote_location_available() -> bool:
return False


def get_dir():
r"""
Get the Torch Hub cache directory used for storing downloaded models & weights.

If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
filesystem layout, with a default value ``~/.cache`` if the environment
variable is not set.
"""
# Issue warning to move data if old env is set
return os.path.join(_get_torch_home(), 'hub')


def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
r"""Loads the Torch serialized object at the given URL.

If downloaded file is a zip file, it will be automatically
decompressed.

If the object is already present in `model_dir`, it's deserialized and
returned.
The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.

Args:
url (string): URL of the object to download
model_dir (string, optional): directory in which to save the object
map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
progress (bool, optional): whether or not to display a progress bar to stderr.
Default: True
check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file.
Default: False
file_name (string, optional): name for the downloaded file. Filename from ``url`` will be used if not set.

Example:
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

"""
# Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'):
warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')

try:
os.makedirs(model_dir)
except OSError as e:
if e.errno == errno.EEXIST:
# Directory already exists, ignore.
pass
else:
# Unexpected OSError, re-raise.
raise

parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
download_url_to_file(url, cached_file, hash_prefix, progress=progress)

if _is_legacy_zip_format(cached_file):
return _legacy_zip_load(cached_file, model_dir, map_location)
return bm.load(cached_file, map_location=map_location)


def _legacy_zip_load(filename, model_dir, map_location):
warnings.warn('Falling back to the old format < 1.6. This support will be '
'deprecated in favor of default zipfile format introduced in 1.6. '
'Please redo torch.save() to save it in the new zipfile format.')
# Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
# We deliberately don't handle tarfile here since our legacy serialization format was in tar.
# E.g. resnet18-5c106cde.pth which is widely used.
with zipfile.ZipFile(filename) as f:
members = f.infolist()
if len(members) != 1:
raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
f.extractall(model_dir)
extracted_name = members[0].filename
extracted_file = os.path.join(model_dir, extracted_name)
return bm.load(extracted_file, map_location=map_location)


# Hub used to support automatically extracts from zipfile manually compressed by users.
# The legacy zip format expects only one file from torch.save() < 1.6 in the zip.
# We should remove this support since zipfile is now default zipfile format for torch.save().
def _is_legacy_zip_format(filename):
if zipfile.is_zipfile(filename):
infolist = zipfile.ZipFile(filename).infolist()
return len(infolist) == 1 and not infolist[0].is_dir()
return False


def download_url_to_file(url, dst, hash_prefix=None, progress=True):
r"""Download object at the given URL to a local path.

Args:
url (string): URL of the object to download
dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file``
hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
Default: None
progress (bool, optional): whether or not to display a progress bar to stderr
Default: True

Example:
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

"""
file_size = None
req = Request(url, headers={"User-Agent": "torch.hub"})
u = urlopen(req)
meta = u.info()
if hasattr(meta, 'getheaders'):
content_length = meta.getheaders("Content-Length")
else:
content_length = meta.get_all("Content-Length")
if content_length is not None and len(content_length) > 0:
file_size = int(content_length[0])

# We deliberately save it in a temp file and move it after
# download is complete. This prevents a local working checkpoint
# being overridden by a broken download.
dst = os.path.expanduser(dst)
dst_dir = os.path.dirname(dst)
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

try:
if hash_prefix is not None:
sha256 = hashlib.sha256()
with tqdm(total=file_size, disable=not progress,
unit='B', unit_scale=True, unit_divisor=1024) as pbar:
while True:
buffer = u.read(8192)
if len(buffer) == 0:
break
f.write(buffer)
if hash_prefix is not None:
sha256.update(buffer)
pbar.update(len(buffer))

f.close()
if hash_prefix is not None:
digest = sha256.hexdigest()
if digest[:len(hash_prefix)] != hash_prefix:
raise RuntimeError('invalid hash value (expected "{}", got "{}")'
.format(hash_prefix, digest))
shutil.move(f.name, dst)
finally:
f.close()
if os.path.exists(f.name):
os.remove(f.name)


def _get_extension_path(lib_name):
lib_dir = os.path.dirname(__file__)
if os.name == "nt":
# Register the main torchvision library location on the default DLL path
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
prev_error_mode = kernel32.SetErrorMode(0x0001)

if with_load_library_flags:
kernel32.AddDllDirectory.restype = ctypes.c_void_p

if sys.version_info >= (3, 8):
os.add_dll_directory(lib_dir)
elif with_load_library_flags:
res = kernel32.AddDllDirectory(lib_dir)
if res is None:
err = ctypes.WinError(ctypes.get_last_error())
err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
raise err

kernel32.SetErrorMode(prev_error_mode)

loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)

extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = extfinder.find_spec(lib_name)
if ext_specs is None:
raise ImportError

return ext_specs.origin

+ 264
- 0
brainpy/datasets/base.py

@@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-


import bisect
import warnings
from typing import Any
from typing import Callable, Generic, Iterable, Iterator, List, Optional, Tuple, TypeVar

T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')


__all__ = [
'Dataset',
'IterableDataset',
'ChainDataset',
'StandardTransform'
]


class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.

All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~.Sampler` implementations and the default options
of :class:`~.DataLoader`.

.. note::
:class:`~.DataLoader` by default constructs an index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""

def __getitem__(self, index) -> T_co:
raise NotImplementedError

def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])

# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
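A minimal map-style subclass sketching the `__getitem__`/`__len__` contract described above (the data here is synthetic, purely for illustration):

class SquaresDataset(Dataset):
  """Return (i, i**2) pairs for indices 0 .. n-1."""

  def __init__(self, n):
    self.n = n

  def __getitem__(self, index):
    return index, index ** 2

  def __len__(self):
    return self.n

ds = SquaresDataset(4)
print([ds[i] for i in range(len(ds))])   # [(0, 0), (1, 1), (2, 4), (3, 9)]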


class IterableDataset(Dataset[T_co]):
r"""An iterable Dataset.

All datasets that represent an iterable of data samples should subclass it.
This form of dataset is particularly useful when the data come from a stream.

All subclasses should overwrite :meth:`__iter__`, which would return an
iterator of samples in this dataset.

When a subclass is used with :class:`~.DataLoader`, each
item in the dataset will be yielded from the :class:`~.DataLoader`
iterator. When :attr:`num_workers > 0`, each worker process will have a
different copy of the dataset object, so it is often desired to configure
each copy independently to avoid having duplicate data returned from the
workers. :func:`~.get_worker_info`, when called in a worker
process, returns information about the worker. It can be used in either the
dataset's :meth:`__iter__` method or the :class:`~.DataLoader` 's
:attr:`worker_init_fn` option to modify each copy's behavior.

Example 1: splitting workload across all workers in :meth:`__iter__`::

>>> class MyIterableDataset(.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end > start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... worker_info = .get_worker_info()
... if worker_info is None: # single-process data loading, return the full iterator
... iter_start = self.start
... iter_end = self.end
... else: # in a worker process
... # split workload
... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... iter_start = self.start + worker_id * per_worker
... iter_end = min(iter_start + per_worker, self.end)
... return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

>>> class MyIterableDataset(.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end > start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
... worker_info = .get_worker_info()
... dataset = worker_info.dataset # the dataset copy in this worker process
... overall_start = dataset.start
... overall_end = dataset.end
... # configure the dataset to only process the split workload
... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... dataset.start = overall_start + worker_id * per_worker
... dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Multi-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
"""

def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError

def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])

# No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


class ChainDataset(IterableDataset):
r"""Dataset for chaining multiple :class:`IterableDataset` s.

This class is useful to assemble different existing dataset streams. The
chaining operation is done on-the-fly, so concatenating large-scale
datasets with this class will be efficient.

Args:
datasets (iterable of IterableDataset): datasets to be chained together
"""

def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ChainDataset, self).__init__()
self.datasets = datasets

def __iter__(self):
for d in self.datasets:
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
for x in d:
yield x

def __len__(self):
total = 0
for d in self.datasets:
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
total += len(d)
return total
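# Illustrative sketch (not from the PR): chaining two toy iterable datasets.
# The `Stream` class below is hypothetical and exists only for this example.
class Stream(IterableDataset):
    def __init__(self, items):
        self.items = items
    def __iter__(self):
        return iter(self.items)
    def __len__(self):
        return len(self.items)

chained = Stream([1, 2]) + Stream([3])   # IterableDataset.__add__ -> ChainDataset
assert list(chained) == [1, 2, 3]
assert len(chained) == 3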


class ConcatDataset(Dataset[T_co]):
r"""Dataset as a concatenation of multiple datasets.

This class is useful to assemble different existing datasets.

Args:
datasets (sequence): List of datasets to be concatenated
"""
datasets: List[Dataset[T_co]]
cumulative_sizes: List[int]

@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r

def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ConcatDataset, self).__init__()
self.datasets = list(datasets)
assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type]
for d in self.datasets:
assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
self.cumulative_sizes = self.cumsum(self.datasets)

def __len__(self):
return self.cumulative_sizes[-1]

def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx]

@property
def cummulative_sizes(self):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes


class StandardTransform:
def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
self.transform = transform
self.target_transform = target_transform

def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
if self.transform is not None:
input = self.transform(input)
if self.target_transform is not None:
target = self.target_transform(target)
return input, target

def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

def __repr__(self) -> str:
body = [self.__class__.__name__]
if self.transform is not None:
body += self._format_transform_repr(self.transform, "Transform: ")
if self.target_transform is not None:
body += self._format_transform_repr(self.target_transform, "Target transform: ")

return "\n".join(body)


+ 5
- 0
brainpy/datasets/chaos/__init__.py View File

@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-


from .chaotic_systems import *


brainpy/datasets/chaotic_systems.py → brainpy/datasets/chaos/chaotic_systems.py View File

@@ -3,7 +3,7 @@
import jax.numpy as jnp

from brainpy import math as bm, dyn
from brainpy.integrators import odeint, ddeint, JointEq, IntegratorRunner
from brainpy.integrators import odeint, JointEq, IntegratorRunner

__all__ = [
'henon_map_series',
@@ -164,7 +164,7 @@ def mackey_glass_series(duration, dt=0.1, beta=2., gamma=1., tau=2., n=9.65,
if inits is None:
inits = bm.ones(1) * 1.2
elif isinstance(inits, (float, int)):
inits = bm.asarray([inits], dtype=bm.float_)
inits = bm.asarray([inits], dtype=bm.dftype())
else:
assert isinstance(inits, (bm.ndarray, jnp.ndarray))

@@ -172,8 +172,7 @@ def mackey_glass_series(duration, dt=0.1, beta=2., gamma=1., tau=2., n=9.65,
xdelay = bm.TimeDelay(inits, tau, dt=dt, interp_method='round')
xdelay.data.value = inits + 0.2 * (rng.random((xdelay.num_delay_step,) + inits.shape) - 0.5)

@ddeint(method=method,
state_delays={'x': xdelay})
@odeint(method=method, state_delays={'x': xdelay})
def mg_eq(x, t):
xtau = xdelay(t - tau)
return beta * xtau / (1 + xtau ** n) - gamma * x

+ 4
- 0
brainpy/datasets/vision/__init__.py View File

@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-

from .mnist import *


+ 90
- 0
brainpy/datasets/vision/base.py View File

@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-

import os
import os.path
from typing import Any
from typing import Callable, List, Optional

from ..base import Dataset, StandardTransform

__all__ = [
'VisionDataset'
]


class VisionDataset(Dataset):
"""
Base class for making datasets which are compatible with torchvision.
It is necessary to override the ``__getitem__`` and ``__len__`` methods.

Args:
root (string): Root directory of dataset.
transforms (callable, optional): A function/transform that takes in
an image and a label and returns the transformed versions of both.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.

.. note::

:attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
"""

_repr_indent = 4

def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
if isinstance(root, (str, bytes)):
root = os.path.expanduser(root)
self.root = root

has_transforms = transforms is not None
has_separate_transform = transform is not None or target_transform is not None
if has_transforms and has_separate_transform:
raise ValueError("Only transforms or transform/target_transform can be passed as argument")

# for backwards-compatibility
self.transform = transform
self.target_transform = target_transform

if has_separate_transform:
transforms = StandardTransform(transform, target_transform)
self.transforms = transforms

def __getitem__(self, index: int) -> Any:
"""
Args:
index (int): Index

Returns:
(Any): Sample and meta data, optionally transformed by the respective transforms.
"""
raise NotImplementedError

def __len__(self) -> int:
raise NotImplementedError

def __repr__(self) -> str:
head = "Dataset " + self.__class__.__name__
body = [f"Number of datapoints: {self.__len__()}"]
if self.root is not None:
body.append(f"Root location: {self.root}")
body += self.extra_repr().splitlines()
if hasattr(self, "transforms") and self.transforms is not None:
body += [repr(self.transforms)]
lines = [head] + [" " * self._repr_indent + line for line in body]
return "\n".join(lines)

def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

def extra_repr(self) -> str:
return ""


+ 561
- 0
brainpy/datasets/vision/mnist.py View File

@@ -0,0 +1,561 @@
# -*- coding: utf-8 -*-

import codecs
import os
import os.path
import shutil
import string
import sys
import warnings
from typing import Any
from typing import Callable, Dict, List, Optional, Tuple
from urllib.error import URLError
from brainpy.errors import PackageMissingError

import jax.numpy as jnp
import numpy as np
try:
from PIL import Image
except ImportError:
Image = None

import brainpy.math as bm
from .base import VisionDataset
from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity

__all__ = [
'MNIST',
'FashionMNIST',
'KMNIST',
'EMNIST',
'QMNIST',
]


class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

Args:
root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

mirrors = [
"http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/",
]

resources = [
("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
]

training_file = "training.pt"
test_file = "test.pt"
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]

@property
def train_labels(self):
warnings.warn("train_labels has been renamed targets")
return self.targets

@property
def test_labels(self):
warnings.warn("test_labels has been renamed targets")
return self.targets

@property
def train_data(self):
warnings.warn("train_data has been renamed data")
return self.data

@property
def test_data(self):
warnings.warn("test_data has been renamed data")
return self.data

def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set

if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return

if download:
self.download()

if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

self.data, self.targets = self._load_data()

def _check_legacy_exist(self):
processed_folder_exists = os.path.exists(self.processed_folder)
if not processed_folder_exists:
return False

return all(
check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
)

def _load_legacy_data(self):
# This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
# directly.
data_file = self.training_file if self.train else self.test_file
return jnp.load(os.path.join(self.processed_folder, data_file))

def _load_data(self):
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
data = read_image_file(os.path.join(self.raw_folder, image_file))

label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
targets = read_label_file(os.path.join(self.raw_folder, label_file))

return data, targets

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])

# doing this so that it is consistent with all other datasets
# to return a PIL Image
if Image is None:
raise PackageMissingError('Need pillow to read the image, please install pillow first.')
img = Image.fromarray(np.asarray(img), mode="L")

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self) -> int:
return len(self.data)

@property
def raw_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, "raw")

@property
def processed_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, "processed")

@property
def class_to_idx(self) -> Dict[str, int]:
return {_class: i for i, _class in enumerate(self.classes)}

def _check_exists(self) -> bool:
return all(
check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
for url, _ in self.resources
)

def download(self) -> None:
"""Download the MNIST data if it doesn't exist already."""

if self._check_exists():
return

os.makedirs(self.raw_folder, exist_ok=True)

# download files
for filename, md5 in self.resources:
for mirror in self.mirrors:
url = f"{mirror}{filename}"
try:
print(f"Downloading {url}")
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
except URLError as error:
print(f"Failed to download (trying next):\n{error}")
continue
finally:
print()
break
else:
raise RuntimeError(f"Error downloading {filename}")

def extra_repr(self) -> str:
split = "Train" if self.train is True else "Test"
return f"Split: {split}"


class FashionMNIST(MNIST):
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

Args:
root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]

resources = [
("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
]
classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


class KMNIST(MNIST):
"""`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

Args:
root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]

resources = [
("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
]
classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]


class EMNIST(MNIST):
"""`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

Args:
root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
which one to use.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
# Merged Classes assumes Same structure for both uppercase and lowercase version
_merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
_all_classes = set(string.digits + string.ascii_letters)
classes_split_dict = {
"byclass": sorted(list(_all_classes)),
"bymerge": sorted(list(_all_classes - _merged_classes)),
"balanced": sorted(list(_all_classes - _merged_classes)),
"letters": ["N/A"] + list(string.ascii_lowercase),
"digits": list(string.digits),
"mnist": list(string.digits),
}

def __init__(self, root: str, split: str, **kwargs: Any) -> None:
self.split = verify_str_arg(split, "split", self.splits)
self.training_file = self._training_file(split)
self.test_file = self._test_file(split)
super().__init__(root, **kwargs)
self.classes = self.classes_split_dict[self.split]

@staticmethod
def _training_file(split) -> str:
return f"training_{split}.pt"

@staticmethod
def _test_file(split) -> str:
return f"test_{split}.pt"

@property
def _file_prefix(self) -> str:
return f"emnist-{self.split}-{'train' if self.train else 'test'}"

@property
def images_file(self) -> str:
return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")

@property
def labels_file(self) -> str:
return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")

def _load_data(self):
return read_image_file(self.images_file), read_label_file(self.labels_file)

def _check_exists(self) -> bool:
return all(check_integrity(file) for file in (self.images_file, self.labels_file))

def download(self) -> None:
"""Download the EMNIST data if it doesn't exist already."""

if self._check_exists():
return

os.makedirs(self.raw_folder, exist_ok=True)

download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
gzip_folder = os.path.join(self.raw_folder, "gzip")
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith(".gz"):
extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
shutil.rmtree(gzip_folder)


class QMNIST(MNIST):
"""`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.

Args:
root (string): Root directory of dataset whose ``raw``
subdir contains binary files of the datasets.
what (string,optional): Can be 'train', 'test', 'test10k',
'test50k', or 'nist' for respectively the mnist compatible
training set, the 60k qmnist testing set, the 10k qmnist
examples that match the mnist testing set, the 50k
remaining qmnist testing examples, or all the nist
digits. The default is to select 'train' or 'test'
according to the compatibility argument 'train'.
compat (bool, optional): A boolean that says whether the target
for each example is a class number (for compatibility with
the MNIST dataloader) or a vector containing the
full qmnist information. Default=True.
download (bool, optional): If True, downloads the dataset from
the internet and puts it in root directory. If dataset is
already downloaded, it is not downloaded again.
transform (callable, optional): A function/transform that
takes in a PIL image and returns a transformed
version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform
that takes in the target and transforms it.
train (bool, optional, compatibility): When argument 'what' is
not specified, this boolean decides whether to load the
training set or the testing set. Default: True.
"""

subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment]
"train": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
"ed72d4157d28c017586c42bc6afe6370",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
"0058f8dd561b90ffdd0f734c6a30e5e4",
),
],
"test": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
"1394631089c404de565df7b7aeaf9412",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
"5b5b05890a5e13444e108efe57b788aa",
),
],
"nist": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
"7f124b3b8ab81486c9d8c2749c17f834",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
"5ed0e788978e45d4a8bd4b7caec3d79d",
),
],
}
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]

def __init__(
self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
) -> None:
if what is None:
what = "train" if train else "test"
self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
self.compat = compat
self.data_file = what + ".pt"
self.training_file = self.data_file
self.test_file = self.data_file
super().__init__(root, train, **kwargs)

@property
def images_file(self) -> str:
(url, _), _ = self.resources[self.subsets[self.what]]
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

@property
def labels_file(self) -> str:
_, (url, _) = self.resources[self.subsets[self.what]]
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

def _check_exists(self) -> bool:
return all(check_integrity(file) for file in (self.images_file, self.labels_file))

def _load_data(self):
data = read_sn3_pascalvincent_tensor(self.images_file)
assert data.dtype == jnp.uint8
assert data.ndim == 3

targets = read_sn3_pascalvincent_tensor(self.labels_file).astype(jnp.int_)
assert targets.ndim == 2

# jax arrays are immutable, so slicing already yields new arrays (no clone needed)
if self.what == "test10k":
data = data[0:10000, :, :]
targets = targets[0:10000, :]
elif self.what == "test50k":
data = data[10000:, :, :]
targets = targets[10000:, :]

return data, targets

def download(self) -> None:
"""Download the QMNIST data if it doesn't exist already.
Note that we only download what has been asked for (argument 'what').
"""
if self._check_exists():
return

os.makedirs(self.raw_folder, exist_ok=True)
split = self.resources[self.subsets[self.what]]

for url, md5 in split:
download_and_extract_archive(url, self.raw_folder, md5=md5)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag
img, target = self.data[index], self.targets[index]
if Image is None:
raise PackageMissingError('Need pillow to read the image, please install pillow first.')
img = Image.fromarray(np.asarray(img), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.compat:
target = int(target[0])
if self.target_transform is not None:
target = self.target_transform(target)
return img, target

def extra_repr(self) -> str:
return f"Split: {self.what}"


def get_int(b: bytes) -> int:
return int(codecs.encode(b, "hex"), 16)


SN3_PASCALVINCENT_TYPEMAP = {
8: jnp.uint8,
9: jnp.int8,
11: jnp.int16,
12: jnp.int32,
13: jnp.float32,
14: jnp.float64,
}


def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> jnp.ndarray:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# read
with open(path, "rb") as f:
data = f.read()
# parse
magic = get_int(data[0:4])
nd = magic % 256
ty = magic // 256
assert 1 <= nd <= 3
assert 8 <= ty <= 14
dtype = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]

num_bytes_per_value = jnp.iinfo(dtype).bits // 8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with .frombuffer().
needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
parsed = jnp.frombuffer(bytearray(data), dtype=dtype, offset=(4 * (nd + 1)))
if needs_byte_reversal:
parsed = jnp.flip(parsed, 0)
assert parsed.shape[0] == np.prod(s) or not strict
return parsed.reshape(*s)


def read_label_file(path: str) -> jnp.ndarray:
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert x.dtype == jnp.uint8
assert x.ndim == 1
return x.astype(bm.dftype())


def read_image_file(path: str) -> jnp.ndarray:
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert x.dtype == jnp.uint8
assert x.ndim == 3
return x
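# Illustrative usage sketch (not from the PR). The import path below assumes the
# vision datasets are exposed under `brainpy.datasets.vision`; downloading needs
# network access and reading images needs pillow.
from brainpy.datasets.vision import MNIST

train_set = MNIST(root='~/data', train=True, download=True)
img, label = train_set[0]        # a PIL image and an integer label in 0..9
print(len(train_set), label)     # 60000 training samples
test_set = MNIST(root='~/data', train=False, download=True)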

+ 479
- 0
brainpy/datasets/vision/utils.py View File

@@ -0,0 +1,479 @@
# -*- coding: utf-8 -*-

import bz2
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
import pathlib
import re
import tarfile
import urllib
import urllib.error
import urllib.request
import zipfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
from urllib.parse import urlparse

from brainpy.errors import PackageMissingError

try:
import requests
except ImportError:
requests = None
from tqdm import tqdm

from .._internally_replaced_utils import (
_download_file_from_remote_location,
_is_remote_location_available,
)

# import torch
# from torch.utils.model_zoo import tqdm

USER_AGENT = "pytorch/vision"


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
with open(filename, "wb") as fh:
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
with tqdm(total=response.length) as pbar:
for chunk in iter(lambda: response.read(chunk_size), b""):
if not chunk:
break
pbar.update(chunk_size)
fh.write(chunk)


def gen_bar_updater() -> Callable[[int, int, int], None]:
pbar = tqdm(total=None)

def bar_update(count, block_size, total_size):
if pbar.total is None and total_size:
pbar.total = total_size
progress_bytes = count * block_size
pbar.update(progress_bytes - pbar.n)

return bar_update


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
md5 = hashlib.md5()
with open(fpath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
md5.update(chunk)
return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
if not os.path.isfile(fpath):
return False
if md5 is None:
return True
return check_md5(fpath, md5)


def _get_redirect_url(url: str, max_hops: int = 3) -> str:
initial_url = url
headers = {"Method": "HEAD", "User-Agent": USER_AGENT}

for _ in range(max_hops + 1):
with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
if response.url == url or response.url is None:
return url

url = response.url
else:
raise RecursionError(
f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
)


def _get_google_drive_file_id(url: str) -> Optional[str]:
parts = urlparse(url)

if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
return None

match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
if match is None:
return None

return match.group("id")


def download_url(
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
"""Download a file from a url and place it in root.

Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the basename of the URL
md5 (str, optional): MD5 checksum of the download. If None, do not check
max_redirect_hops (int, optional): Maximum number of redirect hops allowed
"""
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)

os.makedirs(root, exist_ok=True)

# check if file is already present locally
if check_integrity(fpath, md5):
print("Using downloaded and verified file: " + fpath)
return

if _is_remote_location_available():
_download_file_from_remote_location(fpath, url)
else:
# expand redirect chain if needed
url = _get_redirect_url(url, max_hops=max_redirect_hops)

# check if file is located on Google Drive
file_id = _get_google_drive_file_id(url)
if file_id is not None:
return download_file_from_google_drive(file_id, root, filename, md5)

# download the file
try:
print("Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath)
except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined]
if url[:5] == "https":
url = url.replace("https:", "http:")
print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath)
else:
raise e

# check integrity of downloaded file
if not check_integrity(fpath, md5):
raise RuntimeError("File not found or corrupted.")


def list_dir(root: str, prefix: bool = False) -> List[str]:
"""List all directories at a given root

Args:
root (str): Path to directory whose folders need to be listed
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
if prefix is True:
directories = [os.path.join(root, d) for d in directories]
return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
"""List all files ending with a suffix at a given root

Args:
root (str): Path to directory whose files need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
if prefix is True:
files = [os.path.join(root, d) for d in files]
return files


def _quota_exceeded(first_chunk: bytes) -> bool:
try:
return "Google Drive - Quota exceeded" in first_chunk.decode()
except UnicodeDecodeError:
return False


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
"""Download a Google Drive file from and place it in root.

Args:
file_id (str): id of file to be downloaded
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url

url = "https://docs.google.com/uc?export=download"

root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)

os.makedirs(root, exist_ok=True)

if os.path.isfile(fpath) and check_integrity(fpath, md5):
print("Using downloaded and verified file: " + fpath)
else:
if requests is None:
raise PackageMissingError('Need "requests" package, please install it.')
session = requests.Session()

response = session.get(url, params={"id": file_id}, stream=True)
token = _get_confirm_token(response)

if token:
params = {"id": file_id, "confirm": token}
response = session.get(url, params=params, stream=True)

# Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
# with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
# Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
# the first_chunk of the payload
response_content_generator = response.iter_content(32768)
first_chunk = None
while not first_chunk: # filter out keep-alive new chunks
first_chunk = next(response_content_generator)

if _quota_exceeded(first_chunk):
msg = (
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
f"and can only be overcome by trying again later."
)
raise RuntimeError(msg)

_save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
response.close()


def _get_confirm_token(response) -> Optional[str]:
for key, value in response.cookies.items():
if key.startswith("download_warning"):
return value

return None


def _save_response_content(
response_gen: Iterator[bytes],
destination: str,
) -> None:
with open(destination, "wb") as f:
pbar = tqdm(total=None)
progress = 0

for chunk in response_gen:
if chunk: # filter out keep-alive new chunks
f.write(chunk)
progress += len(chunk)
pbar.update(progress - pbar.n)
pbar.close()


def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
tar.extractall(to_path)


_ZIP_COMPRESSION_MAP: Dict[str, int] = {
".bz2": zipfile.ZIP_BZIP2,
".xz": zipfile.ZIP_LZMA,
}


def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
with zipfile.ZipFile(
from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
) as zip:
zip.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
".tar": _extract_tar,
".zip": _extract_zip,
}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
".bz2": bz2.open,
".gz": gzip.open,
".xz": lzma.open,
}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
".tbz": (".tar", ".bz2"),
".tbz2": (".tar", ".bz2"),
".tgz": (".tar", ".gz"),
}


def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
"""Detect the archive type and/or compression of a file.

Args:
file (str): the filename

Returns:
(tuple): tuple of suffix, archive type, and compression

Raises:
RuntimeError: if file has no suffix or suffix is not supported
"""
suffixes = pathlib.Path(file).suffixes
if not suffixes:
raise RuntimeError(
f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
)
suffix = suffixes[-1]

# check if the suffix is a known alias
if suffix in _FILE_TYPE_ALIASES:
return (suffix, *_FILE_TYPE_ALIASES[suffix])

# check if the suffix is an archive type
if suffix in _ARCHIVE_EXTRACTORS:
return suffix, suffix, None

# check if the suffix is a compression
if suffix in _COMPRESSED_FILE_OPENERS:
# check for suffix hierarchy
if len(suffixes) > 1:
suffix2 = suffixes[-2]

# check if the suffix2 is an archive type
if suffix2 in _ARCHIVE_EXTRACTORS:
return suffix2 + suffix, suffix2, suffix

return suffix, None, suffix

valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
r"""Decompress a file.

The compression is automatically detected from the file name.

Args:
from_path (str): Path to the file to be decompressed.
to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
remove_finished (bool): If ``True``, remove the file after the extraction.

Returns:
(str): Path to the decompressed file.
"""
suffix, archive_type, compression = _detect_file_type(from_path)
if not compression:
raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

if to_path is None:
to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")

# We don't need to check for a missing key here, since this was already done in _detect_file_type()
compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
wfh.write(rfh.read())

if remove_finished:
os.remove(from_path)

return to_path


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
"""Extract an archive.

The archive type and a possible compression are automatically detected from the file name. If the file is compressed
but not an archive, the call is dispatched to :func:`_decompress`.

Args:
from_path (str): Path to the file to be extracted.
to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
used.
remove_finished (bool): If ``True``, remove the file after the extraction.

Returns:
(str): Path to the directory the file was extracted to.
"""
if to_path is None:
to_path = os.path.dirname(from_path)

suffix, archive_type, compression = _detect_file_type(from_path)
if not archive_type:
return _decompress(
from_path,
os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
remove_finished=remove_finished,
)

# We don't need to check for a missing key here, since this was already done in _detect_file_type()
extractor = _ARCHIVE_EXTRACTORS[archive_type]

extractor(from_path, to_path, compression)
if remove_finished:
os.remove(from_path)

return to_path


def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)

download_url(url, download_root, filename, md5)

archive = os.path.join(download_root, filename)
print(f"Extracting {archive} to {extract_root}")
extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable: Iterable) -> str:
return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
value: T,
arg: Optional[str] = None,
valid_values: Optional[Iterable[T]] = None,
custom_msg: Optional[str] = None,
) -> T:
if not isinstance(value, (str, bytes)):
if arg is None:
msg = "Expected type str, but got type {type}."
else:
msg = "Expected type str for argument {arg}, but got type {type}."
msg = msg.format(type=type(value), arg=arg)
raise ValueError(msg)

if valid_values is None:
return value

if value not in valid_values:
if custom_msg is not None:
msg = custom_msg
else:
msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
raise ValueError(msg)

return value
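# Quick sanity checks of the helpers above (illustrative, not a test file from the PR).
assert _detect_file_type('mnist.tar.gz') == ('.tar.gz', '.tar', '.gz')
assert _detect_file_type('mnist.tgz') == ('.tgz', '.tar', '.gz')
assert _detect_file_type('images.gz') == ('.gz', None, '.gz')

split = verify_str_arg('train', arg='split', valid_values=('train', 'test'))

# `extract_archive` dispatches on the detected type; a bare ".gz" file is simply
# decompressed next to the original, e.g.:
#   extract_archive('/data/MNIST/raw/train-images-idx3-ubyte.gz')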

+ 8
- 7
brainpy/dyn/__init__.py View File

@@ -1,15 +1,16 @@
# -*- coding: utf-8 -*-

"""
Dynamics simulation module.
Module for brain dynamics model building.
"""


from .neurons import *
from .synapses import *
from .channels import *
from .base import *
from .utils import *
from .neurons.compat import *
from .synapses.compat import *
from .runners import *

from . import neurons, synapses, channels, rates, utils, runners
from . import (channels, neurons, rates, # neuron related
synapses, synouts, synplast, # synapse related
networks,
layers, # ANN related
runners)

+ 999
- 388
brainpy/dyn/base.py View File

@@ -1,96 +1,110 @@
# -*- coding: utf-8 -*-

import math as pm
import warnings
from typing import Union, Dict, Callable, Sequence
import gc
import inspect
from typing import Union, Dict, Callable, Sequence, Optional, Tuple, Any
import collections

import jax.numpy as jnp
import numpy as np

import brainpy.math as bm
from brainpy import tools
from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm
from brainpy.base.base import Base
from brainpy.base.collector import Collector
from brainpy.connect import TwoEndConnector, MatConn, IJConn
from brainpy.errors import ModelBuildError
from brainpy.initialize import Initializer, init_param, Uniform
from brainpy.integrators import Integrator, odeint
from brainpy.tools.others import to_size, size2num
from brainpy.types import Tensor, Shape
from brainpy.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All
from brainpy.errors import ModelBuildError, NoImplementationError, UnsupportedError, MathError
from brainpy.initialize import Initializer, parameter, variable, Uniform, noise as init_noise
from brainpy.integrators import odeint, sdeint
from brainpy.modes import Mode, TrainingMode, BatchingMode, normal
from brainpy.tools.others import to_size, size2num, numba_jit, DotDict
from brainpy.types import Array, Shape

__all__ = [
# general class
'DynamicalSystem',
'Container',
'Network',
'ConstantDelay',
'NeuGroup',
'ConNeuGroup',
'TwoEndConn',

# containers
'Container', 'Network', 'Sequential', 'System',

# channel models
'Channel',

'ContainerWrapper',
# neuron models
'NeuGroup', 'CondNeuGroup',

# synapse models
'SynConn',
'TwoEndConn',
'SynOut', 'NullSynOut',
'SynSTP', 'NullSynSTP',
'SynLTP', 'NullSynLTP',

# slice
'DSView', 'NeuGroupView',
]

_error_msg = 'Unknown type of the update function: {} ({}). ' \
'Currently, BrainPy only supports: \n' \
'1. function \n' \
'2. function name (str) \n' \
'3. tuple/dict of functions \n' \
'4. tuple of function names \n'
SLICE_VARS = 'slice_vars'


class DynamicalSystem(Base):
"""Base Dynamical System class.

Any object that has step functions is a dynamical system.
That is to say, in BrainPy, the essence of the dynamical system
is the "step functions".

Parameters
----------
name : str, optional
The name of the dynamic system.
name : optional, str
The name of the dynamical system.
mode: Mode
The model computation mode. It should be instance of :py:class:`~.Mode`.
"""

"""Global delay variables. Useful when the same target
variable is used in multiple mappings."""
global_delay_vars: Dict[str, bm.LengthDelay] = dict()
'''Online fitting method.'''
online_fit_by: Optional[OnlineAlgorithm]

'''Offline fitting method.'''
offline_fit_by: Optional[OfflineAlgorithm]

'''Global delay data, which stores the delay variables and corresponding delay targets.
This variable is useful when the same target variable is used in multiple mappings,
as it can reduce the duplicate delay variable registration.'''
global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], bm.Variable]] = dict()

def __init__(
self,
name: str = None,
mode: Optional[Mode] = None,
):
# mode setting
if mode is None: mode = normal
if not isinstance(mode, Mode):
raise ValueError(f'Should be instance of {Mode.__name__}, but we got {type(mode)}: {mode}')
self._mode = mode

def __init__(self, name=None):
super(DynamicalSystem, self).__init__(name=name)

# local delay variables
self.local_delay_vars: Dict[str, bm.LengthDelay] = dict()
self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()

def __repr__(self):
return f'{self.__class__.__name__}(name={self.name})'
# fitting parameters
self.online_fit_by = None
self.offline_fit_by = None
self.fit_record = dict()

@property
def steps(self):
warnings.warn('.steps has been deprecated since version 2.0.3.', DeprecationWarning)
return {}

def ints(self, method='absolute'):
"""Collect all integrators in this node and the children nodes.
def mode(self) -> Mode:
"""Mode of the model, which is useful to control the multiple behaviors of the model."""
return self._mode

Parameters
----------
method : str
The method to access the integrators.
@mode.setter
def mode(self, value):
if not isinstance(value, Mode):
raise ValueError(f'Must be instance of {Mode.__name__}, but we got {type(value)}: {value}')
self._mode = value

Returns
-------
collector : Collector
The collection contained (the path, the integrator).
"""
nodes = self.nodes(method=method)
gather = Collector()
for node_path, node in nodes.items():
for k in dir(node):
v = getattr(node, k)
if isinstance(v, Integrator):
gather[f'{node_path}.{k}' if node_path else k] = v
return gather
def __repr__(self):
return f'{self.__class__.__name__}(name={self.name}, mode={self.mode})'

def __call__(self, *args, **kwargs):
"""The shortcut to call ``update`` methods."""
@@ -98,26 +112,23 @@ class DynamicalSystem(Base):

def register_delay(
self,
name: str,
delay_step: Union[int, Tensor, Callable, Initializer],
delay_target: Union[bm.JaxArray, jnp.ndarray],
initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None,
domain: str = 'global'
identifier: str,
delay_step: Optional[Union[int, Array, Callable, Initializer]],
delay_target: bm.Variable,
initial_delay_data: Union[Initializer, Callable, Array, float, int, bool] = None,
):
"""Register delay variable.

Parameters
----------
name: str
identifier: str
The delay variable name.
delay_step: int, JaxArray, ndarray, callable, Initializer
delay_step: Optional, int, JaxArray, ndarray, callable, Initializer
The number of the steps of the delay.
delay_target: JaxArray, ndarray, Variable
The target for delay.
delay_target: Variable
The target variable for delay.
initial_delay_data: float, int, JaxArray, ndarray, callable, Initializer
The initializer for the delay data.
domain: str
The domain of the delay data to store.

Returns
-------
@@ -126,14 +137,17 @@ class DynamicalSystem(Base):
"""
# delay steps
if delay_step is None:
return delay_step
elif isinstance(delay_step, int):
delay_type = 'none'
elif isinstance(delay_step, (int, np.integer, jnp.integer)):
delay_type = 'homo'
elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)):
delay_type = 'heter'
delay_step = bm.asarray(delay_step)
if delay_step.size == 1 and delay_step.ndim == 0:
delay_type = 'homo'
else:
delay_type = 'heter'
delay_step = bm.asarray(delay_step)
elif callable(delay_step):
delay_step = init_param(delay_step, delay_target.shape, allow_none=False)
delay_step = parameter(delay_step, delay_target.shape, allow_none=False)
delay_type = 'heter'
else:
raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support '
@@ -145,44 +159,48 @@ class DynamicalSystem(Base):
'then provide us the number of delay steps.')
if delay_target.shape[0] != delay_step.shape[0]:
raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}')
max_delay_step = int(bm.max(delay_step))
if delay_type != 'none':
max_delay_step = int(bm.max(delay_step))

# delay domain
if domain not in ['global', 'local']:
raise ValueError('"domain" must be a string in ["global", "local"]. '
f'Bug we got {domain}.')
# delay target
if delay_type != 'none':
if not isinstance(delay_target, bm.Variable):
raise ValueError(f'"delay_target" must be an instance of Variable, but we got {type(delay_target)}')

# delay variable
if domain == 'local':
self.local_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
self.register_implicit_nodes(self.local_delay_vars)
else:
if name not in self.global_delay_vars:
self.global_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
# save into local delay vars when first seen "var",
# for later update current value!
self.local_delay_vars[name] = self.global_delay_vars[name]
if delay_type != 'none':
if identifier not in self.global_delay_data:
delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
self.global_delay_data[identifier] = (delay, delay_target)
self.local_delay_vars[identifier] = delay
else:
if self.global_delay_vars[name].num_delay_step - 1 < max_delay_step:
self.global_delay_vars[name].reset(delay_target, max_delay_step, initial_delay_data)
self.register_implicit_nodes(self.global_delay_vars)
delay = self.global_delay_data[identifier][0]
if delay is None:
delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
self.global_delay_data[identifier] = (delay, delay_target)
self.local_delay_vars[identifier] = delay
elif delay.num_delay_step - 1 < max_delay_step:
self.global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data)
else:
self.global_delay_data[identifier] = (None, delay_target)
self.register_implicit_nodes(self.local_delay_vars)
return delay_step

def get_delay_data(
self,
name: str,
delay_step: Union[int, bm.JaxArray, jnp.DeviceArray],
indices: Union[int, bm.JaxArray, jnp.DeviceArray] = None,
identifier: str,
delay_step: Optional[Union[int, bm.JaxArray, jnp.DeviceArray]],
*indices: Union[int, slice, bm.JaxArray, jnp.DeviceArray],
):
"""Get delay data according to the provided delay steps.

Parameters
----------
name: str
identifier: str
The delay variable name.
delay_step: int, JaxArray, ndarray
delay_step: Optional, int, JaxArray, ndarray
The delay length.
indices: optional, int, JaxArray, ndarray
indices: optional, int, slice, JaxArray, ndarray
The indices of the delay.

Returns
@@ -190,66 +208,143 @@ class DynamicalSystem(Base):
delay_data: JaxArray, ndarray
The delay data at the given time.
"""
if name in self.global_delay_vars:
if isinstance(delay_step, int):
return self.global_delay_vars[name](delay_step, indices)
if delay_step is None:
return self.global_delay_data[identifier][1].value

if identifier in self.global_delay_data:
if bm.ndim(delay_step) == 0:
return self.global_delay_data[identifier][0](delay_step, *indices)
else:
if indices is None:
indices = jnp.arange(delay_step.size)
return self.global_delay_vars[name](delay_step, indices)
elif name in self.local_delay_vars:
if isinstance(delay_step, int):
return self.local_delay_vars[name](delay_step)
if len(indices) == 0:
indices = (jnp.arange(delay_step.size),)
return self.global_delay_data[identifier][0](delay_step, *indices)

elif identifier in self.local_delay_vars:
if bm.ndim(delay_step) == 0:
return self.local_delay_vars[identifier](delay_step)
else:
if indices is None:
indices = jnp.arange(delay_step.size)
return self.local_delay_vars[name](delay_step, indices)
if len(indices) == 0:
indices = (jnp.arange(delay_step.size),)
return self.local_delay_vars[identifier](delay_step, *indices)

else:
raise ValueError(f'{name} is not defined in delay variables.')
raise ValueError(f'{identifier} is not defined in delay variables.')

def update_delay(
self,
name: str,
delay_data: Union[float, bm.JaxArray, jnp.ndarray]
):
"""Update the delay according to the delay data.
def update(self, *args, **kwargs):
"""The function to specify the updating rule.

Parameters
----------
name: str
The name of the delay.
delay_data: float, JaxArray, ndarray
The delay data to update at the current time.
Assume any dynamical system depends on the shared variables (`sha`),
such as the time variable ``t``, the time step ``dt``, and the step index ``i``.
"""
if name in self.local_delay_vars:
return self.local_delay_vars[name].update(delay_data)
else:
if name not in self.global_delay_vars:
raise ValueError(f'{name} is not defined in delay variables.')
raise NotImplementedError('Must implement "update" function by subclass self.')

def reset_delay(
self,
name: str,
delay_target: Union[bm.JaxArray, jnp.DeviceArray]
):
"""Reset the delay variable."""
if name in self.local_delay_vars:
return self.local_delay_vars[name].reset(delay_target)
def reset(self, batch_size=None):
"""Reset function which reset the whole variables in the model.
"""
self.reset_state(batch_size)

def reset_state(self, batch_size=None):
"""Reset function which reset the states in the model.
"""
child_nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
if len(child_nodes) > 0:
for node in child_nodes.values():
node.reset_state(batch_size=batch_size)
self.reset_local_delays(child_nodes)
else:
if name not in self.global_delay_vars:
raise ValueError(f'{name} is not defined in delay variables.')
raise NotImplementedError('Must implement "reset_state" function by subclass self. '
f'Error of {self.name}')

def update(self, t, dt):
"""The function to specify the updating rule.
Assume any dynamical system depends on the time variable ``t`` and
the time step ``dt``.
def update_local_delays(self, nodes: Union[Sequence, Dict] = None):
"""Update local delay variables.

This function should be called after updating neuron groups or delay sources.
For example, in a network model, call it after all neuron groups have been
updated so that the delay buffers hold the latest target values.


Parameters
----------
nodes: sequence, dict
The nodes whose delay variables will be updated.
"""
raise NotImplementedError('Must implement "update" function by subclass self.')
# update delays
if nodes is None:
nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values())
elif isinstance(nodes, DynamicalSystem):
nodes = (nodes, )
elif isinstance(nodes, dict):
nodes = tuple(nodes.values())
if not isinstance(nodes, (tuple, list)):
raise ValueError('Please provide nodes as a list/tuple/dict of DynamicalSystem.')
for node in nodes:
for name in node.local_delay_vars:
delay = self.global_delay_data[name][0]
target = self.global_delay_data[name][1]
delay.update(target.value)

def reset_local_delays(self, nodes: Union[Sequence, Dict] = None):
"""Reset local delay variables.

def reset(self):
"""Reset function which reset the whole variables in the model.
Parameters
----------
nodes: sequence, dict
The nodes whose delay variables will be reset.
"""
raise NotImplementedError('Must implement "reset" function by subclass self.')
# reset delays
if nodes is None:
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values()
elif isinstance(nodes, dict):
nodes = nodes.values()
for node in nodes:
for name in node.local_delay_vars:
delay = self.global_delay_data[name][0]
target = self.global_delay_data[name][1]
delay.reset(target.value)

def __del__(self):
"""Function for handling `del` behavior.

This function is used to pop out the variables which registered in global delay data.
"""
if hasattr(self, 'local_delay_vars'):
for key in tuple(self.local_delay_vars.keys()):
val = self.global_delay_data.pop(key)
del val
val = self.local_delay_vars.pop(key)
del val
if hasattr(self, 'implicit_nodes'):
for key in tuple(self.implicit_nodes.keys()):
del self.implicit_nodes[key]
if hasattr(self, 'implicit_vars'):
for key in tuple(self.implicit_vars.keys()):
del self.implicit_vars[key]
for key in tuple(self.__dict__.keys()):
del self.__dict__[key]
gc.collect()

@tools.not_customized
def online_init(self):
raise NoImplementationError('Subclass must implement online_init() function when using OnlineTrainer.')

@tools.not_customized
def offline_init(self):
raise NoImplementationError('Subclass must implement offline_init() function when using OfflineTrainer.')

@tools.not_customized
def online_fit(self,
target: Array,
fit_record: Dict[str, Array]):
raise NoImplementationError('Subclass must implement online_fit() function when using OnlineTrainer.')

@tools.not_customized
def offline_fit(self,
target: Array,
fit_record: Dict[str, Array]):
raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.')

def clear_input(self):
for node in self.nodes(level=1, include_self=False).subset(NeuGroup).unique().values():
node.clear_input()


class Container(DynamicalSystem):
@@ -269,42 +364,56 @@ class Container(DynamicalSystem):
The instance of DynamicalSystem with the format of "key=dynamic_system".
"""

def __init__(self, *ds_tuple, name=None, **ds_dict):
super(Container, self).__init__(name=name)

# children dynamical systems
self.implicit_nodes = Collector()
for ds in ds_tuple:
if not isinstance(ds, DynamicalSystem):
raise ModelBuildError(f'{self.__class__.__name__} receives instances of '
f'DynamicalSystem, however, we got {type(ds)}.')
if ds.name in self.implicit_nodes:
raise ValueError(f'{ds.name} has been paired with {ds}. Please change a unique name.')
self.register_implicit_nodes({node.name: node for node in ds_tuple})
for key, ds in ds_dict.items():
if not isinstance(ds, DynamicalSystem):
raise ModelBuildError(f'{self.__class__.__name__} receives instances of '
f'DynamicalSystem, however, we got {type(ds)}.')
if key in self.implicit_nodes:
raise ValueError(f'{key} has been paired with {ds}. Please change a unique name.')
self.register_implicit_nodes(ds_dict)
def __init__(
self,
*ds_tuple,
name: str = None,
mode: Mode = normal,
**ds_dict
):
super(Container, self).__init__(name=name, mode=mode)

# add tuple-typed components
for module in ds_tuple:
if isinstance(module, DynamicalSystem):
self.implicit_nodes[module.name] = module
elif isinstance(module, (list, tuple)):
for m in module:
if not isinstance(m, DynamicalSystem):
raise ValueError(f'Should be instance of {DynamicalSystem.__name__}. '
f'But we got {type(m)}')
self.implicit_nodes[m.name] = module
elif isinstance(module, dict):
for k, v in module.items():
if not isinstance(v, DynamicalSystem):
raise ValueError(f'Should be instance of {DynamicalSystem.__name__}. '
f'But we got {type(v)}')
self.implicit_nodes[k] = v
else:
raise ValueError(f'Cannot parse sub-systems. They should be {DynamicalSystem.__name__} '
f'or a list/tuple/dict of {DynamicalSystem.__name__}.')
# add dict-typed components
for k, v in ds_dict.items():
if not isinstance(v, DynamicalSystem):
raise ValueError(f'Should be instance of {DynamicalSystem.__name__}. '
f'But we got {type(v)}')
self.implicit_nodes[k] = v

def __repr__(self):
cls_name = self.__class__.__name__
# split = '\n' + (' ' * (len(cls_name) + 1))
split = ', '
children = [f'{key}={str(val)}' for key, val in self.implicit_nodes.items()]
return f'{cls_name}({split.join(children)})'

def update(self, t, dt):
"""Step function of a network.
def update(self, tdi, *args, **kwargs):
"""Update function of a container.

In this update function, the update functions in children systems are
iteratively called.
"""
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
for node in nodes.values():
node.update(t, dt)
node.update(tdi)

def __getitem__(self, item):
"""Wrap the slice access (self['']). """
@@ -321,30 +430,91 @@ class Container(DynamicalSystem):
else:
return super(Container, self).__getattribute__(item)

def reset(self):
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
neuron_groups = nodes.subset(NeuGroup)
synapse_groups = nodes.subset(TwoEndConn)
for node in neuron_groups.values():
node.reset()
for node in synapse_groups.values():
node.reset()
for node in (nodes - neuron_groups - synapse_groups).values():
node.reset()

@classmethod
def has(cls, **children_cls):
"""
class Sequential(Container):
def __init__(
self,
*modules,
name: str = None,
mode: Mode = normal,
**kw_modules
):
super(Sequential, self).__init__(*modules, name=name, mode=mode, **kw_modules)

def __getattr__(self, item):
"""Wrap the dot access ('self.'). """
child_ds = super(Sequential, self).__getattribute__('implicit_nodes')
if item in child_ds:
return child_ds[item]
else:
return super(Sequential, self).__getattribute__(item)

def __getitem__(self, key: Union[int, slice]):
if isinstance(key, str):
if key not in self.implicit_nodes:
raise KeyError(f'Does not find a component named {key} in\n {str(self)}')
return self.implicit_nodes[key]
elif isinstance(key, slice):
keys = tuple(self.implicit_nodes.keys())[key]
components = tuple(self.implicit_nodes.values())[key]
return Sequential(dict(zip(keys, components)))
elif isinstance(key, int):
return tuple(self.implicit_nodes.values())[key]
elif isinstance(key, (tuple, list)):
all_keys = tuple(self.implicit_nodes.keys())
all_vals = tuple(self.implicit_nodes.values())
keys, vals = [], []
for i in key:
if not isinstance(i, int):
  raise KeyError(f'We expected a tuple/list of int, but we got {type(i)}')
keys.append(all_keys[i])
vals.append(all_vals[i])
return Sequential(dict(zip(keys, vals)))
else:
raise KeyError(f'Unknown type of key: {type(key)}')

def __repr__(self):
def f(x):
if not isinstance(x, DynamicalSystem) and callable(x):
signature = inspect.signature(x)
args = [f'{k}={v.default}' for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty]
args = ', '.join(args)
while not hasattr(x, '__name__'):
if not hasattr(x, 'func'):
break
x = x.func # Handle functools.partial
if not hasattr(x, '__name__') and hasattr(x, '__class__'):
return x.__class__.__name__
if args:
return f'{x.__name__}(*, {args})'
return x.__name__
else:
x = repr(x).split('\n')
x = [x[0]] + [' ' + y for y in x[1:]]
return '\n'.join(x)

entries = '\n'.join(f' [{i}] {f(x)}' for i, x in enumerate(self))
return f'{self.__class__.__name__}(\n{entries}\n)'

def update(self, sha: dict, x: Any) -> Array:
"""Update function of a sequential model.

Parameters
----------
children_cls
sha: dict
The shared arguments (ShA) across multiple layers.
x: Any
The input information.

Returns
-------

y: Array
The output tensor.
"""
return ContainerWrapper(master=cls, **children_cls)
for node in self.implicit_nodes.values():
x = node(sha, x)
return x
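A minimal usage sketch (the toy `Scale` layer below is hypothetical; any `DynamicalSystem` whose `update` accepts `(sha, x)` would work): `Sequential` passes the same shared-argument dict to every child and threads the data through them in registration order, exactly as the loop above does.

>>> import brainpy as bp
>>> import brainpy.math as bm
>>> class Scale(bp.dyn.DynamicalSystem):   # hypothetical toy layer
...   def update(self, sha, x):
...     return 2. * x
>>> model = bp.dyn.Sequential(a=Scale(), b=Scale())
>>> sha = dict(t=0., dt=bm.get_dt(), i=0)
>>> y = model.update(sha, bm.ones(4))      # two layers of x*2 -> array of 4.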


class Network(Container):
@@ -365,116 +535,73 @@ class Network(Container):
A dict container of dynamical system.
"""

def __init__(self, *ds_tuple, name=None, **ds_dict):
super(Network, self).__init__(*ds_tuple, name=name, **ds_dict)
def __init__(
self,
*ds_tuple,
name: str = None,
mode: Mode = normal,
**ds_dict
):
super(Network, self).__init__(*ds_tuple,
name=name,
mode=mode,
**ds_dict)

def update(self, *args, **kwargs):
"""Step function of a network.

class ConstantDelay(DynamicalSystem):
"""Class used to model constant delay variables.
In this update function, the update functions in children systems are
iteratively called.
"""
nodes = self.nodes(level=1, include_self=False)
nodes = nodes.subset(DynamicalSystem)
nodes = nodes.unique()
neuron_groups = nodes.subset(NeuGroup)
synapse_groups = nodes.subset(SynConn)
ds_views = nodes.subset(DSView)
other_nodes = nodes - neuron_groups - synapse_groups - ds_views

This class automatically supports a batch dimension on the last axis. For example, if
you run a batch with the shape (10, 100), where `100` is the batch size, then this
class automatically handles your batched data.
For example,
# shared arguments
shared = args[0]

>>> import brainpy as bp
>>> bp.dyn.ConstantDelay(size=(10, 100), delay=10.)
# update synapse nodes
for node in synapse_groups.values():
node.update(shared)

This class also support nonuniform delays.
# update neuron nodes
for node in neuron_groups.values():
node.update(shared)

>>> bp.dyn.ConstantDelay(size=100, delay=bp.math.random.random(100) * 4 + 10)
# update other types of nodes
for node in other_nodes.values():
node.update(shared)

Parameters
----------
size : int, list of int, tuple of int
The delay data size.
delay : int, float, function, ndarray
The delay time, in units of `dt`.
dt: float, optional
The time precision.
name : optional, str
The name of the dynamic system.
"""
# update delays
self.update_local_delays(nodes)
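As a sketch of how this ordering is driven in practice (normally a runner builds the shared-argument dict; the component names below are placeholders), one step of a network can be advanced manually:

>>> # assuming `neurons` and `synapses` are already-built DynamicalSystem instances
>>> net = bp.dyn.Network(neurons, synapses)
>>> shared = dict(t=0., dt=bm.get_dt(), i=0)
>>> net.update(shared)   # synapses -> neurons -> other nodes -> local delays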

def __init__(self, size, delay, dtype=None, dt=None, **kwargs):
# dt
self.dt = bm.get_dt() if dt is None else dt
self.dtype = dtype

# data size
if isinstance(size, int): size = (size,)
if not isinstance(size, (tuple, list)):
raise ModelBuildError(f'"size" must a tuple/list of int, but we got {type(size)}: {size}')
self.size = tuple(size)

# delay time length
self.delay = delay

# data and operations
if isinstance(delay, (int, float)): # uniform delay
self.uniform_delay = True
self.num_step = int(pm.ceil(delay / self.dt)) + 1
self.out_idx = bm.Variable(bm.array([0], dtype=bm.uint32))
self.in_idx = bm.Variable(bm.array([self.num_step - 1], dtype=bm.uint32))
self.data = bm.Variable(bm.zeros((self.num_step,) + self.size, dtype=dtype))
self.num = 1

else: # non-uniform delay
self.uniform_delay = False
if not len(self.size) == 1:
raise NotImplementedError(f'Currently, BrainPy only supports 1D heterogeneous '
f'delays, while we got the heterogeneous delay with '
f'{len(self.size)}-dimensions.')
self.num = tools.size2num(size)
if bm.ndim(delay) != 1:
raise ModelBuildError(f'Only support a 1D non-uniform delay. '
f'But we got {delay.ndim}D: {delay}')
if delay.shape[0] != self.size[0]:
raise ModelBuildError(f"The first shape of the delay time size must "
f"be the same with the delay data size. But "
f"we got {delay.shape[0]} != {self.size[0]}")
delay = bm.around(delay / self.dt)
self.diag = bm.array(bm.arange(self.num), dtype=bm.int_)
self.num_step = bm.array(delay, dtype=bm.uint32) + 1
self.in_idx = bm.Variable(self.num_step - 1)
self.out_idx = bm.Variable(bm.zeros(self.num, dtype=bm.uint32))
self.data = bm.Variable(bm.zeros((self.num_step.max(),) + size, dtype=dtype))

super(ConstantDelay, self).__init__(**kwargs)

def reset(self):
"""Reset the variables."""
self.in_idx[:] = self.num_step - 1
self.out_idx[:] = 0
self.data[:] = 0
def reset_state(self, batch_size=None):
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
neuron_groups = nodes.subset(NeuGroup)
synapse_groups = nodes.subset(SynConn)

@property
def oldest(self):
return self.pull()
# reset neuron nodes
for node in neuron_groups.values():
node.reset_state(batch_size)

@property
def latest(self):
if self.uniform_delay:
return self.data[self.in_idx[0]]
else:
return self.data[self.in_idx, self.diag]
# reset synapse nodes
for node in synapse_groups.values():
node.reset_state(batch_size)

def pull(self):
if self.uniform_delay:
return self.data[self.out_idx[0]]
else:
return self.data[self.out_idx, self.diag]
# reset other types of nodes
for node in (nodes - neuron_groups - synapse_groups).values():
node.reset_state(batch_size)

# reset delays
self.reset_local_delays(nodes)
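A short sketch (placeholder component names): resetting the whole network re-initializes child states and local delays; for children that support batching, passing a `batch_size` re-creates their state variables with a leading batch axis.

>>> net = bp.dyn.Network(neurons, synapses)   # hypothetical components
>>> net.reset_state(batch_size=16)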

def push(self, value):
if self.uniform_delay:
self.data[self.in_idx[0]] = value
else:
self.data[self.in_idx, self.diag] = value

def update(self, t=None, dt=None, **kwargs):
"""Update the delay index."""
self.in_idx[:] = (self.in_idx + 1) % self.num_step
self.out_idx[:] = (self.out_idx + 1) % self.num_step
class System(Network):
pass


class NeuGroup(DynamicalSystem):
@@ -494,68 +621,103 @@ class NeuGroup(DynamicalSystem):
The neuron group geometry.
name : optional, str
The name of the dynamic system.
keep_size: bool
Whether to keep the geometry information.

.. versionadded:: 2.1.13
mode: Mode
.. versionadded:: 2.2.0
"""

def __init__(self,
size: Shape,
name: str = None):
def __init__(
self,
size: Shape,
keep_size: bool = False,
name: str = None,
mode: Mode = normal,
):
# size
if isinstance(size, (list, tuple)):
if len(size) <= 0:
raise ModelBuildError(f'size must be int, or a tuple/list of int. '
f'But we got {type(size)}')
if not isinstance(size[0], int):
if not isinstance(size[0], (int, np.integer)):
raise ModelBuildError('size must be int, or a tuple/list of int. '
f'But we got {type(size)}')
size = tuple(size)
elif isinstance(size, int):
elif isinstance(size, (int, np.integer)):
size = (size,)
else:
raise ModelBuildError('size must be int, or a tuple/list of int. '
f'But we got {type(size)}')
self.size = size
self.keep_size = keep_size
# number of neurons
self.num = tools.size2num(size)

# initialize
super(NeuGroup, self).__init__(name=name)
super(NeuGroup, self).__init__(name=name, mode=mode)

def update(self, t, dt):
@property
def varshape(self):
"""The shape of variables in the neuron group."""
return self.size if self.keep_size else (self.num,)

def __repr__(self):
return f'{self.__class__.__name__}(name={self.name}, mode={self.mode}, size={self.size})'

def get_batch_shape(self, batch_size=None):
if batch_size is None:
return self.varshape
else:
return (batch_size,) + self.varshape
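For instance (instantiating the base class only to illustrate the shape helpers; in practice `NeuGroup` is subclassed), `varshape` preserves the geometry only when `keep_size=True`, and `get_batch_shape` prepends the batch dimension:

>>> g = bp.dyn.NeuGroup(size=(4, 5), keep_size=True)
>>> g.varshape               # (4, 5)
>>> g.get_batch_shape(16)    # (16, 4, 5)
>>> bp.dyn.NeuGroup(size=(4, 5)).varshape   # (20,)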

def update(self, tdi, x=None):
"""The function to specify the updating rule.

Parameters
----------
t : float
The current time.
dt : float
The time step.
tdi : DotDict
The shared arguments, especially time `t`, step `dt`, and iteration `i`.
x: Any
The input for a neuron group.
"""
raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
f'implement "update" function.')

def clear_input(self):
"""Function to clear inputs in the neuron group.
It will be useful when monitoring inputs of the object received."""
pass

def __getitem__(self, item):
return NeuGroupView(target=self, index=item)

class TwoEndConn(DynamicalSystem):

class SynConn(DynamicalSystem):
"""Base class to model two-end synaptic connections.

Parameters
----------
pre : NeuGroup
Pre-synaptic neuron group.
Pre-synaptic neuron group.
post : NeuGroup
Post-synaptic neuron group.
Post-synaptic neuron group.
conn : optional, ndarray, JaxArray, dict, TwoEndConnector
The connection method between pre- and post-synaptic groups.
The connection method between pre- and post-synaptic groups.
name : str, optional
The name of the dynamic system.
The name of the dynamic system.
"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]] = None,
name: str = None
conn: Union[TwoEndConnector, Array, Dict[str, Array]] = None,
name: str = None,
mode: Mode = normal,
):
super(SynConn, self).__init__(name=name, mode=mode)

# pre or post neuron group
# ------------------------
@@ -589,9 +751,11 @@ class TwoEndConn(DynamicalSystem):
else:
raise ModelBuildError(f'Unknown "conn" type: {conn}')

# initialize
# ----------
super(TwoEndConn, self).__init__(name=name)
def __repr__(self):
names = self.__class__.__name__
return (f'{names}(name={self.name}, mode={self.mode}, '
f'{" " * len(names)} pre={self.pre}, '
f'{" " * len(names)} post={self.post})')

def check_pre_attrs(self, *attrs):
"""Check whether pre group satisfies the requirement."""
@@ -613,31 +777,281 @@ class TwoEndConn(DynamicalSystem):
if not hasattr(self.post, attr):
raise ModelBuildError(f'{self} need "pre" neuron group has attribute "{attr}".')

def update(self, tdi, pre_spike=None):
"""The function to specify the updating rule.

Any dynamical system is assumed to depend on the shared variables (``sha``),
such as the time variable ``t``, the time step ``dt``, and the iteration index ``i``.
"""
raise NotImplementedError('The subclass must implement the "update" function.')


class SynComponent(DynamicalSystem):
"""Base class for modeling synaptic components,
including synaptic output, synaptic short-term plasticity,
synaptic long-term plasticity, and others. """

'''Master of this component.'''
master: SynConn

def __init__(self, *args, **kwargs):
super(SynComponent, self).__init__(*args, **kwargs)

self._registered = False

@property
def isregistered(self) -> bool:
"""State of the component, representing whether it has been registered."""
return self._registered

@isregistered.setter
def isregistered(self, val: bool):
if not isinstance(val, bool):
raise ValueError('Must be an instance of bool.')
self._registered = val

def reset_state(self, batch_size=None):
pass

def register_master(self, master: SynConn):
if not isinstance(master, SynConn):
raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}')
if self.isregistered:
raise ValueError('A master has already been registered; another registration was attempted.')
if hasattr(self, 'master') and self.master != master:
raise ValueError('A master has already been registered; another registration was attempted.')
self.master = master
self._registered = True

def __repr__(self):
return self.__class__.__name__

def __call__(self, *args, **kwargs):
return self.filter(*args, **kwargs)

def clone(self) -> 'SynComponent':
"""The function useful to clone a new object when it has been used."""
raise NotImplementedError

def filter(self, g):
raise NotImplementedError

class Channel(DynamicalSystem):
"""Abstract channel model."""

class SynOut(SynComponent):
"""Base class for synaptic current output."""

def __init__(
self,
size: Union[int, Sequence[int]],
name: str = None,
target_var: Union[str, bm.Variable] = None,
):
super(Channel, self).__init__(name=name)
self.size = to_size(size)
self.num = size2num(self.size)
super(SynOut, self).__init__(name=name)
# check target variable
if target_var is not None:
if not isinstance(target_var, (str, bm.Variable)):
raise TypeError('"target_var" must be instance of string or Variable. '
f'But we got {type(target_var)}')
self.target_var: Optional[bm.Variable] = target_var

def register_master(self, master: SynConn):
super(SynOut, self).register_master(master)

# initialize target variable to output
if isinstance(self.target_var, str):
if not hasattr(self.master.post, self.target_var):
raise KeyError(f'Post-synaptic group does not have target variable: {self.target_var}')
self.target_var = getattr(self.master.post, self.target_var)

def filter(self, g):
if self.target_var is None:
return g
else:
self.target_var += g

def update(self, t, dt):
raise NotImplementedError('Must be implemented by the subclass.')
def update(self, tdi):
pass

def current(self):
raise NotImplementedError('Must be implemented by the subclass.')

def reset(self):
raise NotImplementedError('Must be implemented by the subclass.')
class SynSTP(SynComponent):
"""Base class for synaptic short-term plasticity."""

def update(self, tdi, pre_spike):
pass


class SynLTP(SynComponent):
"""Base class for synaptic long-term plasticity."""

def update(self, tdi, pre_spike):
pass


class NullSynOut(SynOut):
def clone(self):
return NullSynOut()


class NullSynSTP(SynSTP):
def clone(self):
return NullSynSTP()

def filter(self, g):
return g


class NullSynLTP(SynLTP):
def clone(self):
return NullSynLTP()

def filter(self, g):
return g


class TwoEndConn(SynConn):
"""Base class to model synaptic connections.

Parameters
----------
pre : NeuGroup
Pre-synaptic neuron group.
post : NeuGroup
Post-synaptic neuron group.
conn : optional, ndarray, JaxArray, dict, TwoEndConnector
The connection method between pre- and post-synaptic groups.
output: Optional, SynOutput
The output for the synaptic current.

.. versionadded:: 2.1.13
The output component for a two-end connection model.

stp: Optional, SynSTP
The short-term plasticity model for the synaptic variables.

.. versionadded:: 2.1.13
The short-term plasticity component for a two-end connection model.

ltp: Optional, SynLTP
The long-term plasticity model for the synaptic variables.

.. versionadded:: 2.1.13
The long-term plasticity component for a two-end connection model.

name: Optional, str
The name of the dynamic system.
"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Array, Dict[str, Array]] = None,
output: SynOut = NullSynOut(),
stp: SynSTP = NullSynSTP(),
ltp: SynLTP = NullSynLTP(),
name: str = None,
mode: Mode = normal,
):
super(TwoEndConn, self).__init__(pre=pre,
post=post,
conn=conn,
name=name,
mode=mode)

# synaptic output
output = NullSynOut() if output is None else output
if output.isregistered: output = output.clone()
if not isinstance(output, SynOut):
raise TypeError(f'output must be instance of {SynOut.__name__}, '
f'but we got {type(output)}')
output.register_master(master=self)
self.output: SynOut = output

# short-term synaptic plasticity
stp = NullSynSTP() if stp is None else stp
if stp.isregistered: stp = stp.clone()
if not isinstance(stp, SynSTP):
raise TypeError(f'Short-term plasticity must be instance of {SynSTP.__name__}, '
f'but we got {type(stp)}')
stp.register_master(master=self)
self.stp: SynSTP = stp

# long-term synaptic plasticity
ltp = NullSynLTP() if ltp is None else ltp
if ltp.isregistered: ltp = ltp.clone()
if not isinstance(ltp, SynLTP):
raise TypeError(f'Long-term plasticity must be instance of {SynLTP.__name__}, '
f'but we got {type(ltp)}')
ltp.register_master(master=self)
self.ltp: SynLTP = ltp

def init_weights(
self,
weight: Union[float, Array, Initializer, Callable],
comp_method: str,
sparse_data: str = 'csr'
) -> Union[float, Array]:
if comp_method not in ['sparse', 'dense']:
raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}')
if sparse_data not in ['csr', 'ij']:
raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')

# connections and weights
if isinstance(self.conn, One2One):
weight = parameter(weight, (self.pre.num,), allow_none=False)
conn_mask = None

elif isinstance(self.conn, All2All):
weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
conn_mask = None

else:
if comp_method == 'sparse':
if sparse_data == 'csr':
conn_mask = self.conn.require('pre2post')
elif sparse_data == 'ij':
conn_mask = self.conn.require('post_ids', 'pre_ids')
else:
raise ValueError(f'Unknown sparse data type: {sparse_data}')
weight = parameter(weight, conn_mask[1].shape, allow_none=False)
elif comp_method == 'dense':
weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
conn_mask = self.conn.require('conn_mat')
else:
raise ValueError(f'Unknown computation method: {comp_method}')

# training weights
if isinstance(self.mode, TrainingMode):
weight = bm.TrainVar(weight)
return weight, conn_mask
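A hedged sketch of how a concrete synapse subclass might call this helper (the attribute names below are illustrative, not part of this diff):

# inside a hypothetical TwoEndConn subclass's __init__
self.g_max, self.conn_mat = self.init_weights(0.1, comp_method='dense')
# dense: weights shaped (pre.num, post.num) plus a boolean connection matrix
self.g_max, self.pre2post = self.init_weights(0.1, comp_method='sparse', sparse_data='csr')
# sparse: weights aligned with the 'pre2post' (CSR) structure required from the connector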

def syn2post_with_all2all(self, syn_value, syn_weight):
if bm.ndim(syn_weight) == 0:
if isinstance(self.mode, BatchingMode):
post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
else:
post_vs = bm.sum(syn_value)
if not self.conn.include_self:
post_vs = post_vs - syn_value
post_vs = syn_weight * post_vs
else:
post_vs = syn_value @ syn_weight
return post_vs

def syn2post_with_one2one(self, syn_value, syn_weight):
return syn_value * syn_weight

def syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
if bm.ndim(syn_weight) == 0:
post_vs = (syn_weight * syn_value) @ conn_mat
else:
post_vs = syn_value @ (syn_weight * conn_mat)
return post_vs
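A tiny worked example of the dense rule with a scalar weight, following the code above: the synaptic values are first scaled and then projected through the connection matrix.

# syn_value = [1., 2.]; syn_weight = 0.5
# conn_mat  = [[1, 0, 1],
#              [0, 1, 1]]
# post_vs   = (0.5 * [1., 2.]) @ conn_mat = [0.5, 1.0, 1.5]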

class ConNeuGroup(NeuGroup, Container):
"""Base class to model conductance-based neuron group.

class CondNeuGroup(NeuGroup, Container):
r"""Base class to model conductance-based neuron group.

The standard formulation for a conductance-based model is given as

@@ -666,6 +1080,7 @@ class ConNeuGroup(NeuGroup, Container):
where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants.

.. versionadded:: 2.1.9
Model the conductance-based neuron model.

Parameters
----------
@@ -676,109 +1091,305 @@ class ConNeuGroup(NeuGroup, Container):
name : optional, str
The neuron group name.

See Also
--------
Channel

"""

def __init__(
self,
size: Shape,
C: Union[float, Tensor, Initializer, Callable] = 1.,
A: Union[float, Tensor, Initializer, Callable] = 1e-3,
V_th: Union[float, Tensor, Initializer, Callable] = 0.,
V_initializer: Union[Initializer, Callable, Tensor] = Uniform(-70, -60.),
keep_size: bool = False,
C: Union[float, Array, Initializer, Callable] = 1.,
A: Union[float, Array, Initializer, Callable] = 1e-3,
V_th: Union[float, Array, Initializer, Callable] = 0.,
V_initializer: Union[Initializer, Callable, Array] = Uniform(-70, -60.),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
**channels
):
NeuGroup.__init__(self, size)
Container.__init__(self, **channels, name=name)
NeuGroup.__init__(self, size, keep_size=keep_size, mode=mode)
Container.__init__(self, **channels, name=name, mode=mode)

# parameters for neurons
self.C = C
self.A = A
self.V_th = V_th
self._V_initializer = V_initializer
self.noise = init_noise(noise, self.varshape, num_vars=3)

# variables
self.V = bm.Variable(init_param(V_initializer, self.num, allow_none=False))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# function
self.integral = odeint(self.derivative, method=method)

def reset(self):
self.V.value = init_param(self._V_initializer, self.num, allow_none=False)
self.spike[:] = False
self.input[:] = 0
if self.noise is None:
self.integral = odeint(f=self.derivative, method=method)
else:
self.integral = sdeint(f=self.derivative, g=self.noise, method=method)

def derivative(self, V, t):
Iext = self.input.value * (1e-3 / self.A)
for ch in self.implicit_nodes.values():
Iext += ch.current(V)
channels = self.nodes(level=1, include_self=False).subset(Channel).unique()
for ch in channels.values():
Iext = Iext + ch.current(V)
return Iext / self.C

def update(self, t, dt):
V = self.integral(self.V.value, t, dt)
for node in self.implicit_nodes.unique().values():
node.update(t, dt, self.V.value)
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)

def update(self, tdi, *args, **kwargs):
V = self.integral(self.V.value, tdi['t'], tdi['dt'])

channels = self.nodes(level=1, include_self=False).subset(Channel).unique()
# check whether the children channels have the correct parents.
check_master(type(self), **channels)

# update variables
for node in channels.values():
node.update(tdi, self.V.value)
self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
self.input[:] = 0.
self.V.value = V

def register_implicit_nodes(self, *channels, **named_channels):
check_master(type(self), *channels, **named_channels)
super(CondNeuGroup, self).register_implicit_nodes(*channels, **named_channels)

class ContainerWrapper(object):
def __init__(self, master, **children):
self.master = master
self.children_cls = children

if not isinstance(master, type):
raise TypeError(f'"master" should be a type. But we got {master}')
# if not issubclass(master, Channel):
# raise TypeError(f'{master} should be a subclass of {Channel.__name__}.')
for key, child in children.items():
if isinstance(child, type):
if not issubclass(child, Channel):
raise TypeError(f'{child} should be a subclass of Base.')
if child.master_cls is None:
raise TypeError(f'{child} should set its master_cls.')
if not issubclass(master, child.master_cls):
raise TypeError(f'Type does not match. {child} requires a master with type '
f'of {child.master_cls}, but the master now is {master}.')
elif isinstance(child, ContainerWrapper):
if not issubclass(child.master, Channel):
raise TypeError(f'{child.master} should be a subclass of Base.')
if child.master.master_cls is None:
raise TypeError(f'{child.master} should set its master_cls.')
if not issubclass(master, child.master.master_cls):
raise TypeError(f'Type does not match. {child.master} requires a master with type '
f'of {child.master.master_cls}, but the master now is {master}.')
def clear_input(self):
"""Useful for monitoring inputs. """
self.input.value = bm.zeros_like(self.input)

else:
raise TypeError(f'The item in children should be a type or '
f'{ContainerWrapper.__name__} instance. But we got {child}')

def __call__(self, size, *shared_args, shared_kwargs=None, **idv_args):
if shared_kwargs is None:
shared_kwargs = dict()

# initialize children classes
children = dict()
for key, cls in self.children_cls.items():
if key in idv_args:
pars = idv_args.pop(key)
else:
pars = dict()
children[key] = cls(size, *shared_args, **shared_kwargs, **pars)

# initialize master class
master = self.master(size, *shared_args, **shared_kwargs, **idv_args, **children)
class Channel(DynamicalSystem):
"""Abstract channel class."""

# assign master or parent to children
for child in children.values():
child.master = master
master_type = CondNeuGroup

return master
def __init__(
self,
size: Union[int, Sequence[int]],
name: str = None,
keep_size: bool = False,
mode: Mode = normal,
):
super(Channel, self).__init__(name=name, mode=mode)
# the geometry size
self.size = to_size(size)
# the number of elements
self.num = size2num(self.size)
# variable shape
self.keep_size = keep_size

@property
def varshape(self):
return self.size if self.keep_size else (self.num,)

def update(self, tdi, V):
raise NotImplementedError('Must be implemented by the subclass.')

def current(self, V):
raise NotImplementedError('Must be implemented by the subclass.')

def reset_state(self, batch_size=None):
raise NotImplementedError('Must be implemented by the subclass.')


def _check(master, child):
if not hasattr(child, 'master_type'):
raise ValueError('Child class should define "master_type" to specify the type of the master. '
f'But we did not find it in {child}')
if not issubclass(master, child.master_type):
raise TypeError(f'Type does not match. {child} requires a master with type '
f'of {child.master_type}, but the master now is {master}.')


def check_master(master, *channels, **named_channels):
for channel in channels:
if isinstance(channel, Channel):
_check(master, channel)
elif isinstance(channel, (list, tuple)):
for ch in channel:
_check(master, ch)
elif isinstance(channel, dict):
for ch in channel.values():
_check(master, ch)
else:
raise ValueError(f'Do not support {type(channel)}.')
for channel in named_channels.values():
if not isinstance(channel, Channel):
raise ValueError(f'Do not support {type(channel)}. ')
_check(master, channel)
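A minimal sketch of the contract enforced here (the leak channel below is hypothetical): every `Channel` subclass declares a `master_type`, and `check_master` rejects channels hosted by an incompatible parent when the group registers or updates its channels.

>>> class IL(Channel):                 # hypothetical leak channel
...   master_type = CondNeuGroup       # only valid inside a conductance-based group
...   def current(self, V):
...     return 0.1 * (-70. - V)
...   def update(self, tdi, V):
...     pass
...   def reset_state(self, batch_size=None):
...     pass
>>> neu = CondNeuGroup(10, IL=IL(10))  # accepted: CondNeuGroup is a subclass of IL.master_type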


class DSView(DynamicalSystem):
"""DSView, an object used to get a view of a dynamical system instance.

It can get a subset view of variables in a dynamical system instance.
For instance,

>>> import brainpy as bp
>>> hh = bp.dyn.HH(10)
>>> bp.dyn.DSView(hh, slice(5, 10, None))
>>> # or, simply
>>> hh[5:]
"""

def __init__(
self,
target: DynamicalSystem,
index: Union[slice, Sequence, Array],
varshape: Tuple[int, ...] = None,
name: str = None,
mode: Mode = None
):
# initialization
DynamicalSystem.__init__(self, name=name, mode=mode)

# check target
if not isinstance(target, DynamicalSystem):
raise TypeError(f'Should be instance of DynamicalSystem, but we got {type(target)}.')
self.target = target # the target object to slice

# check slicing
if isinstance(index, (int, slice)):
index = (index,)
self.index = index # the slice

# get all variables for slicing
if not hasattr(self.target, SLICE_VARS):
if varshape is None:
if isinstance(target, NeuGroup):
varshape = target.varshape
else:
raise UnsupportedError('Should provide varshape when the target does '
f'not define its {SLICE_VARS}')
all_vars = target.vars(level=1, include_self=True, method='relative')
all_vars = {k: v for k, v in all_vars.items() if v.shape_nb == varshape}
else:
all_vars = {}
for var_str in getattr(self.target, SLICE_VARS):
v = eval(f'target.{var_str}')
all_vars[var_str] = v

# slice variables
self.slice_vars = dict()
for k, v in all_vars.items():
if v.batch_axis is not None:
index = ((self.index[:v.batch_axis] +
(slice(None, None, None), ) +
self.index[v.batch_axis:])
if len(self.index) > v.batch_axis else
(self.index + tuple([slice(None, None, None)
for _ in range(v.batch_axis - len(self.index) + 1)])))
else:
index = self.index
self.slice_vars[k] = bm.VariableView(v, index)

# sub-nodes
nodes = target.nodes(method='relative', level=1, include_self=False).subset(DynamicalSystem)
for k, node in nodes.items():
if isinstance(node, NeuGroup):
node = NeuGroupView(node, self.index)
else:
node = DSView(node, self.index, varshape)
setattr(self, k, node)

def __repr__(self):
children = [f'{key}={val.__name__}' for key, val in self.children_cls.items()]
return f'{self.master.__name__}({", ".join(children)})'
return f'{self.__class__.__name__}(target={self.target}, index={self.index})'

def __getattribute__(self, item):
try:
slice_vars = object.__getattribute__(self, 'slice_vars')
if item in slice_vars:
value = slice_vars[item]
return value
return object.__getattribute__(self, item)
except AttributeError:
return object.__getattribute__(self, item)

def __setattr__(self, key, value):
if hasattr(self, 'slice_vars'):
slice_vars = super(DSView, self).__getattribute__('slice_vars')
if key in slice_vars:
v = slice_vars[key]
v.value = value
return
super(DSView, self).__setattr__(key, value)

def update(self, *args, **kwargs):
raise NoImplementationError(f'DSView {self} cannot be updated. Please update its parent {self.target}')

def reset_state(self, batch_size=None):
pass


@numba_jit
def _slice_to_num(slice_: slice, length: int):
# start
start = slice_.start
if start is None:
start = 0
if start < 0:
start = length + start
start = max(start, 0)
# stop
stop = slice_.stop
if stop is None:
stop = length
if stop < 0:
stop = length + stop
stop = min(stop, length)
# step
step = slice_.step
if step is None:
step = 1
# number
num = 0
while start < stop:
start += step
num += 1
return num
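# Worked examples of the counting loop above:
#   _slice_to_num(slice(2, 9, 3), 10) -> start=2, stop=9, step=3; 2 -> 5 -> 8 -> stop, so num = 3
#   _slice_to_num(slice(None, None, None), 10) -> 10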


class NeuGroupView(DSView, NeuGroup):
"""A view for a neuron group instance."""

def __init__(
self,
target: NeuGroup,
index: Union[slice, Sequence, Array],
name: str = None,
mode: Mode = None
):
DSView.__init__(self, target, index)

# check slicing
var_shapes = target.varshape
if len(self.index) > len(var_shapes):
raise ValueError(f"Length of the index should be less than "
f"that of the target's varshape. But we "
f"got {len(self.index)} > {len(var_shapes)}")

# get size
size = []
for i, idx in enumerate(self.index):
if isinstance(idx, int):
size.append(1)
elif isinstance(idx, slice):
size.append(_slice_to_num(idx, var_shapes[i]))
else:
# should be a list/tuple/array of int
# do not check again
if not isinstance(idx, collections.abc.Iterable):
raise TypeError('Should be an iterable object of int.')
size.append(len(idx))
size += list(var_shapes[len(self.index):])

# initialization
NeuGroup.__init__(self, tuple(size), name=name, mode=mode)

+ 1093
- 0
brainpy/dyn/channels/Ca.py

@@ -0,0 +1,1093 @@
# -*- coding: utf-8 -*-

"""
This module implements voltage-dependent calcium channels.

"""

from typing import Union, Callable

import brainpy.math as bm
from brainpy.dyn.base import Channel
from brainpy.initialize import OneInit, Initializer, parameter, variable
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.types import Shape, Array
from brainpy.modes import Mode, BatchingMode, normal
from .base import Calcium, CalciumChannel

__all__ = [
'CalciumFixed',
'CalciumDyna',
'CalciumDetailed',
'CalciumFirstOrder',

'ICa_p2q_ss', 'ICa_p2q_markov',

'ICaN_IS2008',

'ICaT_HM1992',
'ICaT_HP1992',

'ICaHT_HM1992',

'ICaL_IS2008',
]


class CalciumFixed(Calcium):
"""Fixed Calcium dynamics.

This calcium model has no dynamics. It holds fixed reversal
potential :math:`E` and concentration :math:`C`.
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = 120.,
C: Union[float, Array, Initializer, Callable] = 2.4e-4,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
**channels
):
super(CalciumFixed, self).__init__(size,
keep_size=keep_size,
method=method,
name=name,
mode=mode,
**channels)
self.E = parameter(E, self.varshape, allow_none=False)
self.C = parameter(C, self.varshape, allow_none=False)

def update(self, tdi, V):
for node in self.implicit_nodes.values():
node.update(tdi, V, self.C, self.E)

def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None):
C_Ca = self.C if C_Ca is None else C_Ca
E_Ca = self.E if E_Ca is None else E_Ca
for node in self.nodes(level=1, include_self=False).unique().subset(Channel).values():
node.reset_state(V, C_Ca, E_Ca, batch_size=batch_size)


class CalciumDyna(Calcium):
"""Calcium ion flow with dynamics.

Parameters
----------
size: int, tuple of int
The ion size.
keep_size: bool
Keep the geometry size.
C0: float, Array, Initializer, Callable
The Calcium concentration outside of membrane.
T: float, Array, Initializer, Callable
The temperature.
C_initializer: Initializer, Callable, Array
The initializer for Calcium concentration.
method: str
The numerical method.
name: str
The ion name.
"""
R = 8.31441 # gas constant, J*mol-1*K-1
F = 96.489 # the Faraday constant

def __init__(
self,
size: Shape,
keep_size: bool = False,
C0: Union[float, Array, Initializer, Callable] = 2.,
T: Union[float, Array, Initializer, Callable] = 36.,
C_initializer: Union[Initializer, Callable, Array] = OneInit(2.4e-4),
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
**channels
):
super(CalciumDyna, self).__init__(size,
keep_size=keep_size,
method=method,
name=name,
mode=mode,
**channels)

# parameters
self.C0 = parameter(C0, self.varshape, allow_none=False)
self.T = parameter(T, self.varshape, allow_none=False) # temperature
self._C_initializer = C_initializer
self._constant = self.R / (2 * self.F) * (273.15 + self.T)

# variables
self.C = variable(C_initializer, mode, self.varshape) # Calcium concentration
self.E = bm.Variable(self._reversal_potential(self.C),
batch_axis=0 if isinstance(mode, BatchingMode) else None) # Reversal potential

# function
self.integral = odeint(self.derivative, method=method)

def derivative(self, C, t, V):
raise NotImplementedError

def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None):
self.C.value = variable(self._C_initializer, batch_size, self.varshape) if (C_Ca is None) else C_Ca
self.E.value = self._reversal_potential(self.C)
for node in self.nodes(level=1, include_self=False).unique().subset(Channel).values():
node.reset_state(V, self.C, self.E, batch_size=batch_size)

def update(self, tdi, V):
for node in self.nodes(level=1, include_self=False).unique().subset(Channel).values():
node.update(tdi, V, self.C, self.E)
self.C.value = self.integral(self.C.value, tdi['t'], V, tdi['dt'])
self.E.value = self._reversal_potential(self.C)

def _reversal_potential(self, C):
return self._constant * bm.log(self.C0 / C)
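A quick numerical check of this Nernst relation with the default parameters:

# with T = 36., C0 = 2. and C = 2.4e-4 (the defaults above):
#   R * (273.15 + T) / (2 * F) = 8.31441 * 309.15 / (2 * 96.489) ≈ 13.32
#   E = 13.32 * log(2. / 2.4e-4) ≈ 13.32 * 9.03 ≈ 120 mV,
# which matches the fixed reversal potential E = 120. used by CalciumFixed.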


class CalciumDetailed(CalciumDyna):
r"""Dynamical Calcium model proposed.

**1. The dynamics of intracellular** :math:`Ca^{2+}`

The dynamics of intracellular :math:`Ca^{2+}` were determined by two contributions [1]_ :

*(i) Influx of* :math:`Ca^{2+}` *due to Calcium currents*

:math:`Ca^{2+}` ions enter through :math:`Ca^{2+}` channels and diffuse into the
interior of the cell. Only the :math:`Ca^{2+}` concentration in a thin shell beneath
the membrane was modeled. The influx of :math:`Ca^{2+}` into such a thin shell followed:

.. math::

[Ca]_{i}=-\frac{k}{2 F d} I_{Ca}

where :math:`F=96489\, \mathrm{C\, mol^{-1}}` is the Faraday constant,
:math:`d=1\, \mathrm{\mu m}` is the depth of the shell beneath the membrane,
the unit conversion constant is :math:`k=0.1` for :math:`I_T` in
:math:`\mathrm{\mu A/cm^{2}}` and :math:`[Ca]_{i}` in millimolar,
and :math:`I_{Ca}` is the summation of all :math:`Ca^{2+}` currents.

*(ii) Efflux of* :math:`Ca^{2+}` *due to an active pump*

In a thin shell beneath the membrane, :math:`Ca^{2+}` retrieval usually consists of a
combination of several processes, such as binding to :math:`Ca^{2+}` buffers, calcium
efflux due to :math:`Ca^{2+}` ATPase pump activity and diffusion to neighboring shells.
Only the :math:`Ca^{2+}` pump was modeled here. We adopted the following kinetic scheme:

.. math::

Ca _{i}^{2+}+ P \overset{c_1}{\underset{c_2}{\rightleftharpoons}} CaP \xrightarrow{c_3} P+ Ca _{0}^{2+}

where P represents the :math:`Ca^{2+}` pump, CaP is an intermediate state,
:math:`Ca _{ o }^{2+}` is the extracellular :math:`Ca^{2+}` concentration,
and :math:`c_{1}, c_{2}` and :math:`c_{3}` are rate constants. :math:`Ca^{2+}`
ions have a high affinity for the pump :math:`P`, whereas extrusion of
:math:`Ca^{2+}` follows a slower process (Blaustein, 1988 ). Therefore,
:math:`c_{3}` is low compared to :math:`c_{1}` and :math:`c_{2}` and the
Michaelis-Menten approximation can be used for describing the kinetics of the pump.
According to such a scheme, the kinetic equation for the :math:`Ca^{2+}` pump is:

.. math::

\frac{[Ca^{2+}]_{i}}{dt}=-\frac{K_{T}[Ca]_{i}}{[Ca]_{i}+K_{d}}

where :math:`K_{T}=10^{-4}\, \mathrm{mM\, ms^{-1}}` is the product of :math:`c_{3}`
with the total concentration of :math:`P` and :math:`K_{d}=c_{2} / c_{1}=10^{-4}\, \mathrm{mM}`
is the dissociation constant, which can be interpreted here as the value of
:math:`[Ca]_{i}` at which the pump is half activated (if :math:`[Ca]_{i} \ll K_{d}`
then the efflux is negligible).

**2.A simple first-order model**

While, in (Bazhenov, et al., 1998) [2]_, the :math:`Ca^{2+}` dynamics is
described by a simple first-order model,

.. math::

\frac{d\left[Ca^{2+}\right]_{i}}{d t}=-\frac{I_{Ca}}{z F d}+\frac{\left[Ca^{2+}\right]_{rest}-\left[C a^{2+}\right]_{i}}{\tau_{Ca}}

where :math:`I_{Ca}` is the summation of all :math:`Ca ^{2+}` currents, :math:`d`
is the thickness of the perimembrane "shell" in which calcium is able to affect
membrane properties :math:`(1.\, \mathrm{\mu m})`, :math:`z=2` is the valence of the
:math:`Ca ^{2+}` ion, :math:`F` is the Faraday constant, and :math:`\tau_{C a}` is
the :math:`Ca ^{2+}` removal rate. The resting :math:`Ca ^{2+}` concentration was
set to be :math:`\left[ Ca ^{2+}\right]_{\text {rest}}=.05\, \mathrm{\mu M}` .

**3. The reversal potential**

The reversal potential of calcium :math:`Ca ^{2+}` is calculated according to the
Nernst equation:

.. math::

E = k'{RT \over 2F} log{[Ca^{2+}]_0 \over [Ca^{2+}]_i}

where :math:`R=8.31441 \, \mathrm{J} /(\mathrm{mol}^{\circ} \mathrm{K})`,
:math:`T=309.15^{\circ} \mathrm{K}`,
:math:`F=96,489 \mathrm{C} / \mathrm{mol}`,
and :math:`\left[\mathrm{Ca}^{2+}\right]_{0}=2 \mathrm{mM}`.

Parameters
----------
d : float
The thickness of the peri-membrane "shell".
F : float
The Faraday constant. (:math:`C*mmol^{-1}`)
tau : float
The time constant of the :math:`Ca ^{2+}` removal rate. (ms)
C_rest : float
The resting :math:`Ca ^{2+}` concentration.
C0 : float
The :math:`Ca ^{2+}` concentration outside of the membrane.
R : float
The gas constant. (:math:` J*mol^{-1}*K^{-1}`)

References
----------

.. [1] Destexhe, Alain, Agnessa Babloyantz, and Terrence J. Sejnowski.
"Ionic mechanisms for intrinsic slow oscillations in thalamic
relay neurons." Biophysical journal 65, no. 4 (1993): 1538-1552.
.. [2] Bazhenov, Maxim, Igor Timofeev, Mircea Steriade, and Terrence J.
Sejnowski. "Cellular and network models for intrathalamic augmenting
responses during 10-Hz stimulation." Journal of neurophysiology 79,
no. 5 (1998): 2730-2748.

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
T: Union[float, Array, Initializer, Callable] = 36.,
d: Union[float, Array, Initializer, Callable] = 1.,
C_rest: Union[float, Array, Initializer, Callable] = 2.4e-4,
tau: Union[float, Array, Initializer, Callable] = 5.,
C0: Union[float, Array, Initializer, Callable] = 2.,
C_initializer: Union[Initializer, Callable, Array] = OneInit(2.4e-4),
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
**channels
):
super(CalciumDetailed, self).__init__(size,
keep_size=keep_size,
method=method,
name=name,
T=T,
C0=C0,
C_initializer=C_initializer,
mode=mode,
**channels)

# parameters
self.d = parameter(d, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.C_rest = parameter(C_rest, self.varshape, allow_none=False)

def derivative(self, C, t, V):
ICa = self.current(V, C, self.E)
drive = bm.maximum(- ICa / (2 * self.F * self.d), 0.)
return drive + (self.C_rest - C) / self.tau


class CalciumFirstOrder(CalciumDyna):
r"""The first-order calcium concentration model.

.. math::

Ca' = -\alpha I_{Ca} + -\beta Ca

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
T: Union[float, Array, Initializer, Callable] = 36.,
alpha: Union[float, Array, Initializer, Callable] = 0.13,
beta: Union[float, Array, Initializer, Callable] = 0.075,
C0: Union[float, Array, Initializer, Callable] = 2.,
C_initializer: Union[Initializer, Callable, Array] = OneInit(2.4e-4),
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
**channels
):
super(CalciumFirstOrder, self).__init__(size,
keep_size=keep_size,
method=method,
name=name,
T=T,
C0=C0,
C_initializer=C_initializer,
mode=mode,
**channels)

# parameters
self.alpha = parameter(alpha, self.varshape, allow_none=False)
self.beta = parameter(beta, self.varshape, allow_none=False)

def derivative(self, C, t, V):
ICa = self.current(V, C, self.E)
drive = bm.maximum(- self.alpha * ICa, 0.)
return drive - self.beta * C


# -------------------------


class ICa_p2q_ss(CalciumChannel):
r"""The calcium current model of :math:`p^2q` current which described with steady-state format.

The dynamics of this generalized calcium current model is given by:

.. math::

I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\
{dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\
{dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\

where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors,
:math:`E_{Ca}` is the reversal potential of Calcium channel.

Parameters
----------
size: int, tuple of int
The size of the simulation target.
keep_size: bool
Whether to keep the geometry size (otherwise it is flattened).
method: str
The numerical method
name: str
The name of the object.
g_max : float, Array, Callable, Initializer
The maximum conductance.
phi_p : float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
phi_q : float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
phi_p: Union[float, Array, Initializer, Callable] = 3.,
phi_q: Union[float, Array, Initializer, Callable] = 3.,
g_max: Union[float, Array, Initializer, Callable] = 2.,
method: str = 'exp_auto',
mode: Mode = normal,
name: str = None
):
super(ICa_p2q_ss, self).__init__(size,
keep_size=keep_size,
name=name,
mode=mode, )

# parameters
self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)

# variables
self.p = variable(bm.zeros, mode, self.varshape)
self.q = variable(bm.zeros, mode, self.varshape)

# functions
self.integral = odeint(JointEq([self.dp, self.dq]), method=method)

def dp(self, p, t, V):
return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)

def dq(self, q, t, V):
return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)

def update(self, tdi, V, C_Ca, E_Ca):
self.p.value, self.q.value = self.integral(self.p, self.q, tdi['t'], V, tdi['dt'])

def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * self.p * self.q * (E_Ca - V)

def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
self.p.value = self.f_p_inf(V)
self.q.value = self.f_q_inf(V)
if batch_size is not None:
assert self.p.shape[0] == batch_size
assert self.q.shape[0] == batch_size

def f_p_inf(self, V):
raise NotImplementedError

def f_p_tau(self, V):
raise NotImplementedError

def f_q_inf(self, V):
raise NotImplementedError

def f_q_tau(self, V):
raise NotImplementedError


class ICa_p2q_markov(CalciumChannel):
r"""The calcium current model of :math:`p^2q` current which described with first-order Markov chain.

The dynamics of this generalized calcium current model is given by:

.. math::

I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\
{dp \over dt} &= \phi_p (\alpha_p(V)(1-p) - \beta_p(V)p) \\
{dq \over dt} &= \phi_q (\alpha_q(V)(1-q) - \beta_q(V)q) \\

where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors,
:math:`E_{Ca}` is the reversal potential of Calcium channel.

Parameters
----------
size: int, tuple of int
The size of the simulation target.
keep_size: bool
Whether to keep the geometry size (otherwise it is flattened).
method: str
The numerical method
name: str
The name of the object.
g_max : float, Array, Callable, Initializer
The maximum conductance.
phi_p : float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
phi_q : float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
phi_p: Union[float, Array, Initializer, Callable] = 3.,
phi_q: Union[float, Array, Initializer, Callable] = 3.,
g_max: Union[float, Array, Initializer, Callable] = 2.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(ICa_p2q_markov, self).__init__(size,
keep_size=keep_size,
name=name,
mode=mode)

# parameters
self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)

# variables
self.p = variable(bm.zeros, mode, self.varshape)
self.q = variable(bm.zeros, mode, self.varshape)

# functions
self.integral = odeint(JointEq([self.dp, self.dq]), method=method)

def dp(self, p, t, V):
return self.phi_p * (self.f_p_alpha(V) * (1 - p) - self.f_p_beta(V) * p)

def dq(self, q, t, V):
return self.phi_q * (self.f_q_alpha(V) * (1 - q) - self.f_q_beta(V) * q)

def update(self, tdi, V, C_Ca, E_Ca):
self.p.value, self.q.value = self.integral(self.p, self.q, tdi['t'], V, tdi['dt'])

def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * self.p * self.q * (E_Ca - V)

def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
alpha, beta = self.f_p_alpha(V), self.f_p_beta(V)
self.p.value = alpha / (alpha + beta)
alpha, beta = self.f_q_alpha(V), self.f_q_beta(V)
self.q.value = alpha / (alpha + beta)
if batch_size is not None:
assert self.p.shape[0] == batch_size
assert self.q.shape[0] == batch_size
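# Note: setting dp/dt = 0 in the Markov kinetics above gives the steady state
#   p_inf = alpha_p / (alpha_p + beta_p)   (and likewise for q),
# which is why reset_state initializes p and q to alpha / (alpha + beta).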

def f_p_alpha(self, V):
raise NotImplementedError

def f_p_beta(self, V):
raise NotImplementedError

def f_q_alpha(self, V):
raise NotImplementedError

def f_q_beta(self, V):
raise NotImplementedError


class ICaN_IS2008(CalciumChannel):
r"""The calcium-activated non-selective cation channel model
proposed by (Inoue & Strowbridge, 2008) [2]_.

The dynamics of the calcium-activated non-selective cation channel model [1]_ [2]_ is given by:

.. math::

\begin{aligned}
I_{CAN} &=g_{\mathrm{max}} M\left([Ca^{2+}]_{i}\right) p \left(V-E\right)\\
&M\left([Ca^{2+}]_{i}\right) ={[Ca^{2+}]_{i} \over 0.2+[Ca^{2+}]_{i}} \\
&{dp \over dt} = {\phi \cdot (p_{\infty}-p)\over \tau_p} \\
&p_{\infty} = {1.0 \over 1 + \exp(-(V + 43) / 5.2)} \\
&\tau_{p} = {2.7 \over \exp(-(V + 55) / 15) + \exp((V + 55) / 15)} + 1.6
\end{aligned}

where :math:`\phi` is the temperature factor.

Parameters
----------
g_max : float
The maximal conductance density (:math:`mS/cm^2`).
E : float
The reversal potential (mV).
phi : float
The temperature factor.

References
----------

.. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated
thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818.
.. [2] Inoue T, Strowbridge BW (2008) Transient activity induces a long-lasting
increase in the excitability of olfactory bulb interneurons.
J Neurophysiol 99: 187–199.
"""

'''The type of the master object.'''
master_type = CalciumDyna

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = 10.,
g_max: Union[float, Array, Initializer, Callable] = 1.,
phi: Union[float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(ICaN_IS2008, self).__init__(size,
keep_size=keep_size,
name=name,
mode=mode)

# parameters
self.E = parameter(E, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)
self.phi = parameter(phi, self.varshape, allow_none=False)

# variables
self.p = variable(bm.zeros, mode, self.varshape)

# function
self.integral = odeint(self.derivative, method=method)

def derivative(self, p, t, V):
  p_inf = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2))
  tau_p = 2.7 / (bm.exp(-(V + 55.) / 15.) + bm.exp((V + 55.) / 15.)) + 1.6
  return self.phi * (p_inf - p) / tau_p

def update(self, tdi, V, C_Ca, E_Ca):
self.p.value = self.integral(self.p, tdi['t'], V, tdi['dt'])

def current(self, V, C_Ca, E_Ca):
M = C_Ca / (C_Ca + 0.2)
g = self.g_max * M * self.p
return g * (self.E - V)

def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
self.p.value = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2))
if batch_size is not None:
assert self.p.shape[0] == batch_size


class ICaT_HM1992(ICa_p2q_ss):
r"""The low-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_.

The dynamics of the low-threshold T-type calcium current model [1]_ is given by:

.. math::

I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\
{dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\
&p_{\infty} = {1 \over 1+\exp [-(V+59-V_{sh}) / 6.2]} \\
&\tau_{p} = 0.612 + {1 \over \exp [-(V+132.-V_{sh}) / 16.7]+\exp [(V+16.8-V_{sh}) / 18.2]} \\
{dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\
&q_{\infty} = {1 \over 1+\exp [(V+83-V_{sh}) / 4]} \\
& \begin{array}{l} \tau_{q} = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\
\tau_{q} = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array}

where :math:`\phi_p = 3.55^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}`
are temperature-dependent factors (:math:`T` is the temperature in Celsius),
:math:`E_{Ca}` is the reversal potential of Calcium channel.

Parameters
----------
T : float, Array
The temperature.
T_base_p : float, Array
The base temperature factor of :math:`p` channel.
T_base_q : float, Array
The base temperature factor of :math:`q` channel.
g_max : float, Array, Callable, Initializer
The maximum conductance.
V_sh : float, Array, Callable, Initializer
The membrane potential shift.
phi_p : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
phi_q : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.

References
----------

.. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in
rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383.

See Also
--------
ICa_p2q_form
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
T: Union[float, Array] = 36.,
T_base_p: Union[float, Array] = 3.55,
T_base_q: Union[float, Array] = 3.,
g_max: Union[float, Array, Initializer, Callable] = 2.,
V_sh: Union[float, Array, Initializer, Callable] = -3.,
phi_p: Union[float, Array, Initializer, Callable] = None,
phi_q: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p
phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q
super(ICaT_HM1992, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
g_max=g_max,
phi_p=phi_p,
phi_q=phi_q,
mode=mode)

# parameters
self.T = parameter(T, self.varshape, allow_none=False)
self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False)
self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_inf(self, V):
return 1. / (1 + bm.exp(-(V + 59. - self.V_sh) / 6.2))

def f_p_tau(self, V):
return 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) +
bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612

def f_q_inf(self, V):
return 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.0))

def f_q_tau(self, V):
return bm.where(V >= (-80. + self.V_sh),
bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28.,
bm.exp((V + 467. - self.V_sh) / 66.6))


class ICaT_HP1992(ICa_p2q_ss):
r"""The low-threshold T-type calcium current model for thalamic
reticular nucleus proposed by (Huguenard & Prince, 1992) [1]_.

The dynamics of the low-threshold T-type calcium current model in thalamic
reticular nucleus neurons [1]_ is given by:

.. math::

I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\
{dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\
&p_{\infty} = {1 \over 1+\exp [-(V+52-V_{sh}) / 7.4]} \\
&\tau_{p} = 3+{1 \over \exp [(V+27-V_{sh}) / 10]+\exp [-(V+102-V_{sh}) / 15]} \\
{dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\
&q_{\infty} = {1 \over 1+\exp [(V+80-V_{sh}) / 5]} \\
& \tau_q = 85+ {1 \over \exp [(V+48-V_{sh}) / 4]+\exp [-(V+407-V_{sh}) / 50]}

where :math:`\phi_p = 5^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}`
are temperature-dependent factors (:math:`T` is the temperature in Celsius),
:math:`E_{Ca}` is the reversal potential of Calcium channel.

Parameters
----------
T : float, Array
The temperature.
T_base_p : float, Array
The base temperature factor of :math:`p` channel.
T_base_q : float, Array
The base temperature factor of :math:`q` channel.
g_max : float, Array, Callable, Initializer
The maximum conductance.
V_sh : float, Array, Callable, Initializer
The membrane potential shift.
phi_p : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
phi_q : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.

References
----------

.. [1] Huguenard JR, Prince DA (1992) A novel T-type current underlies
prolonged Ca2+- dependent burst firing in GABAergic neurons of rat
thalamic reticular nucleus. J Neurosci 12: 3804–3817.

See Also
--------
ICa_p2q_form
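
Examples
--------
A quick arithmetic check (not from the reference) of the default temperature
factors implied by the formulas above at :math:`T = 36`:

>>> phi_p = 5. ** ((36. - 24.) / 10.)   # ≈ 6.9 for the activation gate
>>> phi_q = 3. ** ((36. - 24.) / 10.)   # ≈ 3.7 for the inactivation gate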
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
T: Union[float, Array] = 36.,
T_base_p: Union[float, Array] = 5.,
T_base_q: Union[float, Array] = 3.,
g_max: Union[float, Array, Initializer, Callable] = 1.75,
V_sh: Union[float, Array, Initializer, Callable] = -3.,
phi_p: Union[float, Array, Initializer, Callable] = None,
phi_q: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p
phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q
super(ICaT_HP1992, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
g_max=g_max,
phi_p=phi_p,
phi_q=phi_q,
mode=mode)

# parameters
self.T = parameter(T, self.varshape, allow_none=False)
self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False)
self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_inf(self, V):
return 1. / (1. + bm.exp(-(V + 52. - self.V_sh) / 7.4))

def f_p_tau(self, V):
return 3. + 1. / (bm.exp((V + 27. - self.V_sh) / 10.) +
bm.exp(-(V + 102. - self.V_sh) / 15.))

def f_q_inf(self, V):
return 1. / (1. + bm.exp((V + 80. - self.V_sh) / 5.))

def f_q_tau(self, V):
return 85. + 1. / (bm.exp((V + 48. - self.V_sh) / 4.) +
bm.exp(-(V + 407. - self.V_sh) / 50.))


class ICaHT_HM1992(ICa_p2q_ss):
r"""The high-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_.

The high-threshold T-type calcium current model is adopted from [1]_.
Its dynamics is given by

.. math::

\begin{aligned}
I_{\mathrm{Ca/HT}} &= g_{\mathrm{max}} p^2 q (V-E_{Ca})
\\
{dp \over dt} &= {\phi_{p} \cdot (p_{\infty} - p) \over \tau_{p}} \\
&\tau_{p} =\frac{1}{\exp \left(\frac{V+132-V_{sh}}{-16.7}\right)+\exp \left(\frac{V+16.8-V_{sh}}{18.2}\right)}+0.612 \\
& p_{\infty} = {1 \over 1+\exp [-(V+59-V_{sh}) / 6.2]}
\\
{dq \over dt} &= {\phi_{q} \cdot (q_{\infty} - q) \over \tau_{q}} \\
& \begin{array}{l} \tau_q = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\
\tau_q = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array} \\
&q_{\infty} = {1 \over 1+\exp [(V+83 -V_{sh})/4]}
\end{aligned}

where :math:`\phi_p = 3.55^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}`
are temperature-dependent factors (:math:`T` is the temperature in Celsius),
:math:`E_{Ca}` is the reversal potential of Calcium channel.

Parameters
----------
T : float, Array
The temperature.
T_base_p : float, Array
The base temperature factor of :math:`p` channel.
T_base_q : float, Array
The base temperature factor of :math:`q` channel.
g_max : float, Array, Initializer, Callable
The maximum conductance.
V_sh : float, Array, Initializer, Callable
The membrane potential shift.

References
----------
.. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in
rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383.

See Also
--------
ICa_p2q_form
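
Examples
--------
The "high-threshold" character comes mainly from the large default ``V_sh``:
activation is half-maximal at :math:`V = V_{sh} - 59` mV, i.e. near -34 mV here
versus roughly -62 mV for :class:`ICaT_HM1992`. A small sketch (standalone
instantiation is an assumption):

>>> ht = ICaHT_HM1992(size=1)        # V_sh = 25 mV by default
>>> p_half = ht.f_p_inf(25. - 59.)   # equals 0.5 at V = V_sh - 59 mV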
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
T: Union[float, Array] = 36.,
T_base_p: Union[float, Array] = 3.55,
T_base_q: Union[float, Array] = 3.,
g_max: Union[float, Array, Initializer, Callable] = 2.,
V_sh: Union[float, Array, Initializer, Callable] = 25.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(ICaHT_HM1992, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
g_max=g_max,
phi_p=T_base_p ** ((T - 24) / 10),
phi_q=T_base_q ** ((T - 24) / 10),
mode=mode)

# parameters
self.T = parameter(T, self.varshape, allow_none=False)
self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False)
self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

# variables
self.p = variable(bm.zeros, mode, self.varshape)
self.q = variable(bm.zeros, mode, self.varshape)

# function
self.integral = odeint(JointEq([self.dp, self.dq]), method=method)

def f_p_inf(self, V):
return 1. / (1. + bm.exp(-(V + 59. - self.V_sh) / 6.2))

def f_p_tau(self, V):
return 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) +
bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612

def f_q_inf(self, V):
return 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.))

def f_q_tau(self, V):
return bm.where(V >= (-80. + self.V_sh),
bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28.,
bm.exp((V + 467. - self.V_sh) / 66.6))


class ICaHT_Re1993(ICa_p2q_markov):
r"""The high-threshold T-type calcium current model proposed by (Reuveni, et al., 1993) [1]_.

HVA Calcium current was described for neocortical neurons by Sayer et al. (1990).
Its dynamics is given by (the rate functions were measured at 36 Celsius):

.. math::

\begin{aligned}
I_{\mathrm{Ca/HT}} &= \bar{g} p^{2} q \left(V-E_{\mathrm{Ca}}\right) \\
\frac{\mathrm{d} p}{\mathrm{~d} t} &= \phi_p (\alpha_{p}(V)(1-p)-\beta_{p}(V) p) \\
\frac{\mathrm{d} q}{\mathrm{~d} t} &= \phi_q (\alpha_{q}(V)(1-q)-\beta_{q}(V) q) \\
\alpha_{p} &=\frac{0.055(-27-V+V_{sh})}{\exp [(-27-V+V_{sh}) / 3.8]-1} \\
\beta_{p} &=0.94 \exp [(-75-V+V_{sh}) / 17] \\
\alpha_{q} &=0.000457 \exp [(-13-V+V_{sh}) / 50] \\
\beta_{q} &=\frac{0.0065}{\exp [(-15-V+V_{sh}) / 28]+1},
\end{aligned}

Parameters
----------
size: int, tuple of int
The size of the simulation target.
keep_size: bool
Keep size or flatten the size?
method: str
The numerical method
name: str
The name of the object.
g_max : float, Array, Callable, Initializer
The maximum conductance.
V_sh : float, Array, Callable, Initializer
The membrane potential shift.
T : float, Array
The temperature.
T_base_p : float, Array
The base temperature factor of :math:`p` channel.
T_base_q : float, Array
The base temperature factor of :math:`q` channel.
phi_p : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
If `None`, :math:`\phi_p = T_{\mathrm{base\_p}}^{\frac{T-23}{10}}`.
phi_q : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.
If `None`, :math:`\phi_q = T_{\mathrm{base\_q}}^{\frac{T-23}{10}}`.

References
----------
.. [1] Reuveni, I., et al. "Stepwise repolarization from Ca2+ plateaus
in neocortical pyramidal cells: evidence for nonhomogeneous
distribution of HVA Ca2+ channels in dendrites." Journal of
Neuroscience 13.11 (1993): 4609-4621.
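
Examples
--------
A minimal sketch (standalone instantiation is an assumption, not shown in the
reference): the steady-state activation follows from the Markov rates as
:math:`p_{\infty} = \alpha_p / (\alpha_p + \beta_p)`:

>>> chan = ICaHT_Re1993(size=1)
>>> alpha, beta = chan.f_p_alpha(-20.), chan.f_p_beta(-20.)
>>> p_inf = alpha / (alpha + beta)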

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
T: Union[float, Array] = 36.,
T_base_p: Union[float, Array] = 2.3,
T_base_q: Union[float, Array] = 2.3,
phi_p: Union[float, Array, Initializer, Callable] = None,
phi_q: Union[float, Array, Initializer, Callable] = None,
g_max: Union[float, Array, Initializer, Callable] = 1.,
V_sh: Union[float, Array, Initializer, Callable] = 0.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
phi_p = T_base_p ** ((T - 23.) / 10.) if phi_p is None else phi_p
phi_q = T_base_q ** ((T - 23.) / 10.) if phi_q is None else phi_q
super(ICaHT_Re1993, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
g_max=g_max,
phi_p=phi_p,
phi_q=phi_q,
mode=mode)
self.T = parameter(T, self.varshape, allow_none=False)
self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False)
self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_alpha(self, V):
temp = -27 - V + self.V_sh
return 0.055 * temp / (bm.exp(temp / 3.8) - 1)

def f_p_beta(self, V):
return 0.94 * bm.exp((-75. - V + self.V_sh) / 17.)

def f_q_alpha(self, V):
return 0.000457 * bm.exp((-13. - V + self.V_sh) / 50.)

def f_q_beta(self, V):
return 0.0065 / (bm.exp((-15. - V + self.V_sh) / 28.) + 1.)


class ICaL_IS2008(ICa_p2q_ss):
r"""The L-type calcium channel model proposed by (Inoue & Strowbridge, 2008) [1]_.

The L-type calcium channel model is adopted from (Inoue, et al., 2008) [1]_.
Its dynamics is given by:

.. math::

I_{CaL} &= g_{max} p^2 q(V-E_{Ca}) \\
{dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\
& p_{\infty} = {1 \over 1+\exp [-(V+10-V_{sh}) / 4.]} \\
& \tau_{p} = 0.4+{0.7 \over \exp [(V+5-V_{sh}) / 15]+\exp [-(V+5-V_{sh}) / 15]} \\
{dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\
& q_{\infty} = {1 \over 1+\exp [(V+25-V_{sh}) / 2]} \\
& \tau_q = 300 + {100 \over \exp [(V+40-V_{sh}) / 9.5]+\exp [-(V+40-V_{sh}) / 9.5]}

where :math:`\phi_p = 3.55^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}`
are temperature-dependent factors (:math:`T` is the temperature in Celsius),
:math:`E_{Ca}` is the reversal potential of Calcium channel.

Parameters
----------
T : float
The temperature.
T_base_p : float
The base temperature factor of :math:`p` channel.
T_base_q : float
The base temperature factor of :math:`q` channel.
g_max : float
The maximum conductance.
V_sh : float
The membrane potential shift.

References
----------

.. [1] Inoue, Tsuyoshi, and Ben W. Strowbridge. "Transient activity induces a long-lasting
increase in the excitability of olfactory bulb interneurons." Journal of
neurophysiology 99, no. 1 (2008): 187-199.

See Also
--------
ICa_p2q_form
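
Examples
--------
A short sketch (standalone instantiation assumed) illustrating that the
inactivation gate of this L-type channel is slow, with :math:`\tau_q` of a few
hundred milliseconds:

>>> ical = ICaL_IS2008(size=1)
>>> tau_q = ical.f_q_tau(-40.)   # = 350 ms at V = -40 mV with V_sh = 0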
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
T: Union[float, Array, Initializer, Callable] = 36.,
T_base_p: Union[float, Array, Initializer, Callable] = 3.55,
T_base_q: Union[float, Array, Initializer, Callable] = 3.,
g_max: Union[float, Array, Initializer, Callable] = 1.,
V_sh: Union[float, Array, Initializer, Callable] = 0.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(ICaL_IS2008, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
g_max=g_max,
phi_p=T_base_p ** ((T - 24) / 10),
phi_q=T_base_q ** ((T - 24) / 10),
mode=mode)

# parameters
self.T = parameter(T, self.varshape, allow_none=False)
self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False)
self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_inf(self, V):
return 1. / (1 + bm.exp(-(V + 10. - self.V_sh) / 4.))

def f_p_tau(self, V):
return 0.4 + .7 / (bm.exp(-(V + 5. - self.V_sh) / 15.) +
bm.exp((V + 5. - self.V_sh) / 15.))

def f_q_inf(self, V):
return 1. / (1. + bm.exp((V + 25. - self.V_sh) / 2.))

def f_q_tau(self, V):
return 300. + 100. / (bm.exp((V + 40 - self.V_sh) / 9.5) +
bm.exp(-(V + 40 - self.V_sh) / 9.5))

+ 0
- 862
brainpy/dyn/channels/Ca_channels.py View File

@@ -1,862 +0,0 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable

import brainpy.math as bm
from brainpy.dyn.base import Container, ConNeuGroup
from brainpy.initialize import OneInit, Initializer, init_param
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.types import Shape, Tensor
from .base import Ion, IonChannel

__all__ = [
'Calcium',
'CalciumFixed',
'CalciumDetailed',
'CalciumAbstract',

'CalciumChannel',
'IAHP',
'ICaN',
'ICaT',
'ICaT_RE',
'ICaHT',
'ICaL',
]


class Calcium(Ion, Container):
"""The base calcium dynamics.

Parameters
----------
size: int, sequence of int
The size of the simulation target.
method: str
The numerical integration method.
name: str
The name of the object.
**channels
The calcium dependent channels.
"""

'''The type of the master object.'''
master_cls = ConNeuGroup

def __init__(
self,
size: Shape,
method: str = 'exp_auto',
name: str = None,
**channels
):
Ion.__init__(self, size)
Container.__init__(self, name=name, **channels)
self.method = method

def current(self, V, C_Ca=None, E_Ca=None):
C_Ca = self.C if C_Ca is None else C_Ca
E_Ca = self.E if E_Ca is None else E_Ca
nodes = list(self.implicit_nodes.values())
current = nodes[0].current(V, C_Ca, E_Ca)
for node in nodes[1:]:
current += node.current(V, C_Ca, E_Ca)
return current


class CalciumFixed(Calcium):
"""Fixed Calcium dynamics.

This calcium model has no dynamics. It only holds a fixed reversal potential :math:`E`.
"""

def __init__(
self,
size: Shape,
E: Union[float, Tensor, Initializer, Callable] = 120.,
C: Union[float, Tensor, Initializer, Callable] = 0.05,
method: str = 'exp_auto',
name: str = None,
**channels
):
super(CalciumFixed, self).__init__(size, method=method, name=name, **channels)
self.E = init_param(E, self.num, allow_none=False)
self.C = init_param(C, self.num, allow_none=False)

def update(self, t, dt, V):
for node in self.implicit_nodes.values():
node.update(t, dt, V, self.C, self.E)

def reset(self, V, C_Ca=None, E_Ca=None):
C_Ca = self.C if C_Ca is None else C_Ca
E_Ca = self.E if E_Ca is None else E_Ca
for node in self.implicit_nodes.values():
node.reset(V, C_Ca, E_Ca)


class CalciumDetailed(Calcium):
r"""Dynamical Calcium model.

**1. The dynamics of intracellular** :math:`Ca^{2+}`

The dynamics of intracellular :math:`Ca^{2+}` were determined by two contributions [1]_ :

*(i) Influx of* :math:`Ca^{2+}` *due to Calcium currents*

:math:`Ca^{2+}` ions enter through :math:`Ca^{2+}` channels and diffuse into the
interior of the cell. Only the :math:`Ca^{2+}` concentration in a thin shell beneath
the membrane was modeled. The influx of :math:`Ca^{2+}` into such a thin shell followed:

.. math::

\frac{d[Ca]_{i}}{dt}=-\frac{k}{2 F d} I_{Ca}

where :math:`F=96489\, \mathrm{C\, mol^{-1}}` is the Faraday constant,
:math:`d=1\, \mathrm{\mu m}` is the depth of the shell beneath the membrane,
the unit conversion constant is :math:`k=0.1` for :math:`I_T` in
:math:`\mathrm{\mu A/cm^{2}}` and :math:`[Ca]_{i}` in millimolar,
and :math:`I_{Ca}` is the summation of all :math:`Ca^{2+}` currents.

*(ii) Efflux of* :math:`Ca^{2+}` *due to an active pump*

In a thin shell beneath the membrane, :math:`Ca^{2+}` retrieval usually consists of a
combination of several processes, such as binding to :math:`Ca^{2+}` buffers, calcium
efflux due to :math:`Ca^{2+}` ATPase pump activity and diffusion to neighboring shells.
Only the :math:`Ca^{2+}` pump was modeled here. We adopted the following kinetic scheme:

.. math::

Ca _{i}^{2+}+ P \overset{c_1}{\underset{c_2}{\rightleftharpoons}} CaP \xrightarrow{c_3} P+ Ca _{0}^{2+}

where P represents the :math:`Ca^{2+}` pump, CaP is an intermediate state,
:math:`Ca _{ o }^{2+}` is the extracellular :math:`Ca^{2+}` concentration,
and :math:`c_{1}, c_{2}` and :math:`c_{3}` are rate constants. :math:`Ca^{2+}`
ions have a high affinity for the pump :math:`P`, whereas extrusion of
:math:`Ca^{2+}` follows a slower process (Blaustein, 1988 ). Therefore,
:math:`c_{3}` is low compared to :math:`c_{1}` and :math:`c_{2}` and the
Michaelis-Menten approximation can be used for describing the kinetics of the pump.
According to such a scheme, the kinetic equation for the :math:`Ca^{2+}` pump is:

.. math::

\frac{d[Ca^{2+}]_{i}}{dt}=-\frac{K_{T}[Ca]_{i}}{[Ca]_{i}+K_{d}}

where :math:`K_{T}=10^{-4}\, \mathrm{mM\, ms^{-1}}` is the product of :math:`c_{3}`
with the total concentration of :math:`P` and :math:`K_{d}=c_{2} / c_{1}=10^{-4}\, \mathrm{mM}`
is the dissociation constant, which can be interpreted here as the value of
:math:`[Ca]_{i}` at which the pump is half activated (if :math:`[Ca]_{i} \ll K_{d}`
then the efflux is negligible).

**2. A simple first-order model**

While, in (Bazhenov, et al., 1998) [2]_, the :math:`Ca^{2+}` dynamics is
described by a simple first-order model,

.. math::

\frac{d\left[Ca^{2+}\right]_{i}}{d t}=-\frac{I_{Ca}}{z F d}+\frac{\left[Ca^{2+}\right]_{rest}-\left[C a^{2+}\right]_{i}}{\tau_{Ca}}

where :math:`I_{Ca}` is the summation of all :math:`Ca ^{2+}` currents, :math:`d`
is the thickness of the perimembrane "shell" in which calcium is able to affect
membrane properties :math:`(1\, \mathrm{\mu m})`, :math:`z=2` is the valence of the
:math:`Ca ^{2+}` ion, :math:`F` is the Faraday constant, and :math:`\tau_{C a}` is
the :math:`Ca ^{2+}` removal rate. The resting :math:`Ca ^{2+}` concentration was
set to be :math:`\left[ Ca ^{2+}\right]_{\text {rest}}=.05\, \mathrm{\mu M}` .

**3. The reversal potential**

The reversal potential of calcium :math:`Ca ^{2+}` is calculated according to the
Nernst equation:

.. math::

E = k'{RT \over 2F} log{[Ca^{2+}]_0 \over [Ca^{2+}]_i}

where :math:`R=8.31441 \, \mathrm{J} /(\mathrm{mol}^{\circ} \mathrm{K})`,
:math:`T=309.15^{\circ} \mathrm{K}`,
:math:`F=96,489 \mathrm{C} / \mathrm{mol}`,
and :math:`\left[\mathrm{Ca}^{2+}\right]_{0}=2 \mathrm{mM}`.

Parameters
----------
d : float
The thickness of the peri-membrane "shell".
F : float
The Faraday constant. (:math:`C*mmol^{-1}`)
tau : float
The time constant of the :math:`Ca ^{2+}` removal rate. (ms)
C_rest : float
The resting :math:`Ca ^{2+}` concentration.
C_0 : float
The :math:`Ca ^{2+}` concentration outside of the membrane.
R : float
The gas constant. (:math:` J*mol^{-1}*K^{-1}`)

References
----------

.. [1] Destexhe, Alain, Agnessa Babloyantz, and Terrence J. Sejnowski. "Ionic mechanisms for intrinsic slow oscillations in thalamic relay neurons." Biophysical journal 65, no. 4 (1993): 1538-1552.
.. [2] Bazhenov, Maxim, Igor Timofeev, Mircea Steriade, and Terrence J. Sejnowski. "Cellular and network models for intrathalamic augmenting responses during 10-Hz stimulation." Journal of neurophysiology 79, no. 5 (1998): 2730-2748.

"""

R = 8.31441 # gas constant, J*mol-1*K-1
F = 96.489 # the Faraday constant

def __init__(
self,
size: Shape,
d: Union[float, Tensor, Initializer, Callable] = 1.,
C_rest: Union[float, Tensor, Initializer, Callable] = 0.05,
tau: Union[float, Tensor, Initializer, Callable] = 5.,
C_0: Union[float, Tensor, Initializer, Callable] = 2.,
T: Union[float, Tensor, Initializer, Callable] = 36.,
C_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.05),
E_initializer: Union[Initializer, Callable, Tensor] = OneInit(120.),
method: str = 'exp_auto',
name: str = None,
**channels
):
super(CalciumDetailed, self).__init__(size, method=method, name=name, **channels)

# parameters
self.T = init_param(T, self.num, allow_none=False) # temperature
self.d = init_param(d, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.C_rest = init_param(C_rest, self.num, allow_none=False)
self.C_0 = init_param(C_0, self.num, allow_none=False)
self._E_initializer = E_initializer
self._C_initializer = C_initializer

# variables
self.C = bm.Variable(init_param(C_initializer, self.num)) # Calcium concentration
self.E = bm.Variable(init_param(E_initializer, self.num)) # Reversal potential

# function
self.integral = odeint(self.derivative, method=method)

def reset(self, V, C_Ca=None, E_Ca=None):
self.C[:] = init_param(self._C_initializer, self.num) if (C_Ca is None) else C_Ca
self.E[:] = init_param(self._E_initializer, self.num) if (E_Ca is None) else E_Ca
for node in self.implicit_nodes.values():
node.reset(V, self.C, self.E)

def derivative(self, C, t, V):
ICa = self.current(V, C, self.E)
return - ICa / (2 * self.F * self.d) + (self.C_rest - C) / self.tau

def update(self, t, dt, V):
C = self.integral(self.C.value, t, V, dt)
for node in self.implicit_nodes.values():
node.update(t, dt, V, self.C, self.E)
self.E.value = self.R * (273.15 + self.T) / (2 * self.F) * bm.log(self.C_0 / C)
self.C.value = C


class CalciumAbstract(Calcium):
r"""The first-order calcium concentration model.

.. math::

Ca' = -\alpha I_{Ca} - \beta Ca



"""
def __init__(
self,
size: Shape,
alpha: Union[float, Tensor, Initializer, Callable] = 0.13,
beta: Union[float, Tensor, Initializer, Callable] = 0.075,
C_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.05),
E_initializer: Union[Initializer, Callable, Tensor] = OneInit(120.),
method: str = 'exp_auto',
name: str = None
):
super(CalciumAbstract, self).__init__(size, name=name)

# parameters
self.alpha = init_param(alpha, self.num, allow_none=False)
self.beta = init_param(beta, self.num, allow_none=False)

# variables
self.C = bm.Variable(init_param(C_initializer, self.num)) # Calcium concentration
self.E = bm.Variable(init_param(E_initializer, self.num)) # Reversal potential

# functions
self.integral = odeint(self.derivative, method=method)

def reset(self, V, C_Ca=None, E_Ca=None):
self.C[:] = init_param(self._C_initializer, self.num) if (C_Ca is None) else C_Ca
self.E[:] = init_param(self._E_initializer, self.num) if (E_Ca is None) else E_Ca
for node in self.implicit_nodes.values():
node.reset(V, self.C, self.E)

def derivative(self, C, t, V):
ICa = self.current(V, C, self.E)
return - self.alpha * ICa - self.beta * C

def update(self, t, dt, V):
C = self.integral(self.C.value, t, V, dt)
for node in self.implicit_nodes.values():
node.update(t, dt, V, self.C, self.E)
self.E.value = self.R * (273.15 + self.T) / (2 * self.F) * bm.log(self.C_0 / C)
self.C.value = C


# -------------------------


class CalciumChannel(IonChannel):
"""Base class for Calcium ion channels."""

'''The type of the master object.'''
master_cls = Calcium

def update(self, t, dt, V, C_Ca, E_Ca):
raise NotImplementedError

def current(self, V, C_Ca, E_Ca):
raise NotImplementedError

def reset(self, V, C_Ca, E_Ca):
raise NotImplementedError


class IAHP(CalciumChannel):
r"""The calcium-dependent potassium current model.

The dynamics of the calcium-dependent potassium current model is given by:

.. math::

\begin{aligned}
I_{AHP} &= g_{\mathrm{max}} p (V - E) \\
{dp \over dt} &= {p_{\infty}(V) - p \over \tau_p(V)} \\
p_{\infty} &=\frac{48[Ca^{2+}]_i}{\left(48[Ca^{2+}]_i +0.09\right)} \\
\tau_p &=\frac{1}{\left(48[Ca^{2+}]_i +0.09\right)}
\end{aligned}

where :math:`E` is the reversal potential, :math:`g_{max}` is the maximum conductance.


Parameters
----------
g_max : float
The maximal conductance density (:math:`mS/cm^2`).
E : float
The reversal potential (mV).

References
----------

.. [1] Contreras, D., R. Curró Dossi, and M. Steriade. "Electrophysiological
properties of cat reticular thalamic neurones in vivo." The Journal of
Physiology 470.1 (1993): 273-294.
.. [2] Mulle, Ch, Anamaria Madariaga, and M. Deschênes. "Morphology and
electrophysiological properties of reticularis thalami neurons in
cat: in vivo study of a thalamic pacemaker." Journal of
Neuroscience 6.8 (1986): 2134-2145.
.. [3] Avanzini, G., et al. "Intrinsic properties of nucleus reticularis
thalami neurones of the rat studied in vitro." The Journal of
Physiology 416.1 (1989): 111-122.
.. [4] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated
thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818.
.. [5] Vijayan S, Kopell NJ (2012) Thalamic model of awake alpha oscillations and
implications for stimulus processing. Proc Natl Acad Sci USA 109: 18553–18558.

"""

'''The type of the master object.'''
master_cls = CalciumDetailed

def __init__(
self,
size: Shape,
E: Union[float, Tensor, Initializer, Callable] = -80.,
g_max: Union[float, Tensor, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None
):
super(IAHP, self).__init__(size, name=name)

# parameters
self.E = init_param(E, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)

# variables
self.p = bm.Variable(bm.zeros(self.num))

# function
self.integral = odeint(self.derivative, method=method)

def derivative(self, p, t, V, C_Ca, E_Ca):
C2 = 48 * C_Ca ** 2
C3 = C2 + 0.09
return (C2 / C3 - p) * C3

def update(self, t, dt, V, C_Ca, E_Ca):
self.p.value = self.integral(self.p, t, C=C_Ca, dt=dt)

def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * (self.E - V)

def reset(self, V, C_Ca, E_Ca):
C2 = 48 * C_Ca ** 2
C3 = C2 + 0.09
self.p.value = C2 / C3


class ICaN(CalciumChannel):
r"""The calcium-activated non-selective cation channel model.

The dynamics of the calcium-activated non-selective cation channel model is given by:

.. math::

\begin{aligned}
I_{CAN} &=g_{\mathrm{max}} M\left([Ca^{2+}]_{i}\right) p \left(V-E\right)\\
&M\left([Ca^{2+}]_{i}\right) ={[Ca^{2+}]_{i} \over 0.2+[Ca^{2+}]_{i}} \\
&{dp \over dt} = {\phi \cdot (p_{\infty}-p)\over \tau_p} \\
&p_{\infty} = {1.0 \over 1 + \exp(-(V + 43) / 5.2)} \\
&\tau_{p} = {2.7 \over \exp(-(V + 55) / 15) + \exp((V + 55) / 15)} + 1.6
\end{aligned}

where :math:`\phi` is the temperature factor.

Parameters
----------
g_max : float
The maximal conductance density (:math:`mS/cm^2`).
E : float
The reversal potential (mV).
phi : float
The temperature factor.

References
----------

.. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated
thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818.
.. [2] Inoue T, Strowbridge BW (2008) Transient activity induces a long-lasting
increase in the excitability of olfactory bulb interneurons.
J Neurophysiol 99: 187–199.
"""

'''The type of the master object.'''
master_cls = CalciumDetailed

def __init__(
self,
size: Shape,
E: Union[float, Tensor, Initializer, Callable] = 10.,
g_max: Union[float, Tensor, Initializer, Callable] = 1.,
phi: Union[float, Tensor, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None
):
super(ICaN, self).__init__(size, name=name)

# parameters
self.E = init_param(E, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)
self.phi = init_param(phi, self.num, allow_none=False)

# variables
self.p = bm.Variable(bm.zeros(self.num))

# function
self.integral = odeint(self.derivative, method=method)

def derivative(self, p, t, V):
phi_p = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2))
p_inf = 2.7 / (bm.exp(-(V + 55.) / 15.) + bm.exp((V + 55.) / 15.)) + 1.6
return self.phi * (phi_p - p) / p_inf

def update(self, t, dt, V, C_Ca, E_Ca):
self.p.value = self.integral(self.p, t, V, dt)

def current(self, V, C_Ca, E_Ca):
M = C_Ca / (C_Ca + 0.2)
g = self.g_max * M * self.p
return g * (self.E - V)

def reset(self, V, C_Ca, E_Ca):
self.p.value = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2))


class ICaT(CalciumChannel):
r"""The low-threshold T-type calcium current model.

The dynamics of the low-threshold T-type calcium current model [1]_ is given by:

.. math::

I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\
{dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\
&p_{\infty} = {1 \over 1+\exp [-(V+59-V_{sh}) / 6.2]} \\
&\tau_{p} = 0.612 + {1 \over \exp [-(V+132.-V_{sh}) / 16.7]+\exp [(V+16.8-V_{sh}) / 18.2]} \\
{dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\
&q_{\infty} = {1 \over 1+\exp [(V+83-V_{sh}) / 4]} \\
& \begin{array}{l} \tau_{q} = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\
\tau_{q} = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array}

where :math:`\phi_p = 3.55^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}`
are temperature-dependent factors (:math:`T` is the temperature in Celsius),
:math:`E_{Ca}` is the reversal potential of Calcium channel.

Parameters
----------
T : float
The temperature.
T_base_p : float
The base temperature factor of :math:`p` channel.
T_base_q : float
The base temperature factor of :math:`q` channel.
g_max : float
The maximum conductance.
V_sh : float
The membrane potential shift.

References
----------

.. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in
rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383.
"""

def __init__(
self,
size: Shape,
T: Union[float, Tensor, Initializer, Callable] = 36.,
T_base_p: Union[float, Tensor, Initializer, Callable] = 3.55,
T_base_q: Union[float, Tensor, Initializer, Callable] = 3.,
g_max: Union[float, Tensor, Initializer, Callable] = 2.,
V_sh: Union[float, Tensor, Initializer, Callable] = -3.,
method: str = 'exp_auto',
name: str = None
):
super(ICaT, self).__init__(size, name=name)

# parameters
self.T = init_param(T, self.num, allow_none=False)
self.T_base_p = init_param(T_base_p, self.num, allow_none=False)
self.T_base_q = init_param(T_base_q, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)
self.V_sh = init_param(V_sh, self.num, allow_none=False)
self.phi_p = self.T_base_p ** ((self.T - 24) / 10)
self.phi_q = self.T_base_q ** ((self.T - 24) / 10)

# variables
self.p = bm.Variable(bm.zeros(self.num))
self.q = bm.Variable(bm.zeros(self.num))

# functions
self.integral = odeint(JointEq([self.dp, self.dq]), method=method)

def dp(self, p, t, V):
p_inf = 1. / (1 + bm.exp(-(V + 59. - self.V_sh) / 6.2))
p_tau = 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) + bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612
return self.phi_p * (p_inf - p) / p_tau

def dq(self, q, t, V):
q_inf = 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.0))
q_tau = bm.where(V >= (-80. + self.V_sh),
bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28.,
bm.exp((V + 467. - self.V_sh) / 66.6))
return self.phi_q * (q_inf - q) / q_tau

def update(self, t, dt, V, C_Ca, E_Ca):
self.p.value, self.q.value = self.integral(self.p, self.q, t, V, dt)

def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * self.p * self.q * (E_Ca - V)

def reset(self, V, C_Ca, E_Ca):
self.p.value = 1. / (1 + bm.exp(-(V + 59. - self.V_sh) / 6.2))
self.q.value = 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.0))


class ICaT_RE(CalciumChannel):
r"""The low-threshold T-type calcium current model in thalamic reticular nucleus.

The dynamics of the low-threshold T-type calcium current model [1]_ [2]_ in thalamic
reticular nucleus neurons is given by:

.. math::

I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\
{dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\
&p_{\infty} = {1 \over 1+\exp [-(V+52-V_{sh}) / 7.4]} \\
&\tau_{p} = 3+{1 \over \exp [(V+27-V_{sh}) / 10]+\exp [-(V+102-V_{sh}) / 15]} \\
{dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\
&q_{\infty} = {1 \over 1+\exp [(V+80-V_{sh}) / 5]} \\
& \tau_q = 85+ {1 \over \exp [(V+48-V_{sh}) / 4]+\exp [-(V+407-V_{sh}) / 50]}

where :math:`\phi_p = 5^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}`
are temperature-dependent factors (:math:`T` is the temperature in Celsius),
:math:`E_{Ca}` is the reversal potential of Calcium channel.

Parameters
----------
T : float
The temperature.
T_base_p : float
The base temperature factor of :math:`p` channel.
T_base_q : float
The base temperature factor of :math:`q` channel.
g_max : float
The maximum conductance.
V_sh : float
The membrane potential shift.

References
----------

.. [1] Avanzini, G., et al. "Intrinsic properties of nucleus reticularis thalami
neurones of the rat studied in vitro." The Journal of
Physiology 416.1 (1989): 111-122.
.. [2] Bal, Thierry, and DAVID A. McCORMICK. "Mechanisms of oscillatory activity
in guinea‐pig nucleus reticularis thalami in vitro: a mammalian
pacemaker." The Journal of Physiology 468.1 (1993): 669-691.

"""

def __init__(
self,
size: Shape,
T: Union[float, Tensor, Initializer, Callable] = 36.,
T_base_p: Union[float, Tensor, Initializer, Callable] = 5.,
T_base_q: Union[float, Tensor, Initializer, Callable] = 3.,
g_max: Union[float, Tensor, Initializer, Callable] = 1.75,
V_sh: Union[float, Tensor, Initializer, Callable] = -3.,
method='exp_auto',
name=None
):
super(ICaT_RE, self).__init__(size, name=name)

# parameters
self.T = init_param(T, self.num, allow_none=False)
self.T_base_p = init_param(T_base_p, self.num, allow_none=False)
self.T_base_q = init_param(T_base_q, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)
self.V_sh = init_param(V_sh, self.num, allow_none=False)
self.phi_p = self.T_base_p ** ((self.T - 24) / 10)
self.phi_q = self.T_base_q ** ((self.T - 24) / 10)

# variables
self.p = bm.Variable(bm.zeros(self.num))
self.q = bm.Variable(bm.zeros(self.num))

# function
self.integral = odeint(JointEq([self.dp, self.dq]), method=method)

def dp(self, p, t, V):
p_inf = 1. / (1. + bm.exp(-(V + 52. - self.V_sh) / 7.4))
p_tau = 3. + 1. / (bm.exp((V + 27. - self.V_sh) / 10.) + bm.exp(-(V + 102. - self.V_sh) / 15.))
return self.phi_p * (p_inf - p) / p_tau

def dq(self, q, t, V):
q_inf = 1. / (1. + bm.exp((V + 80. - self.V_sh) / 5.))
q_tau = 85. + 1. / (bm.exp((V + 48. - self.V_sh) / 4.) + bm.exp(-(V + 407. - self.V_sh) / 50.))
return self.phi_q * (q_inf - q) / q_tau

def update(self, t, dt, V, C_Ca, E_Ca):
self.p.value, self.q.value = self.integral(self.p, self.q, t, V, dt)

def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * self.p * self.q * (E_Ca - V)

def reset(self, V, C_Ca, E_Ca):
self.p.value = 1. / (1. + bm.exp(-(V + 52. - self.V_sh) / 7.4))
self.q.value = 1. / (1. + bm.exp((V + 80. - self.V_sh) / 5.))


class ICaHT(CalciumChannel):
r"""The high-threshold T-type calcium current model.

The high-threshold T-type calcium current model is adopted from [1]_.
Its dynamics is given by

.. math::

\begin{aligned}
I_{\mathrm{Ca/HT}} &= g_{\mathrm{max}} p^2 q (V-E_{Ca})
\\
{dp \over dt} &= {\phi_{p} \cdot (p_{\infty} - p) \over \tau_{p}} \\
&\tau_{p} =\frac{1}{\exp \left(\frac{V+132-V_{sh}}{-16.7}\right)+\exp \left(\frac{V+16.8-V_{sh}}{18.2}\right)}+0.612 \\
& p_{\infty} = {1 \over 1+\exp [-(V+59-V_{sh}) / 6.2]}
\\
{dq \over dt} &= {\phi_{q} \cdot (q_{\infty} - q) \over \tau_{q}} \\
& \begin{array}{l} \tau_q = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\
\tau_q = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array} \\
&q_{\infty} = {1 \over 1+\exp [(V+83 -V_{sh})/4]}
\end{aligned}

where :math:`\phi_p = 3.55^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}`
are temperature-dependent factors (:math:`T` is the temperature in Celsius),
:math:`E_{Ca}` is the reversal potential of Calcium channel.

Parameters
----------
T : float
The temperature.
T_base_p : float
The base temperature factor of :math:`p` channel.
T_base_q : float
The base temperature factor of :math:`q` channel.
g_max : float
The maximum conductance.
V_sh : float
The membrane potential shift.

References
----------
.. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in
rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383.
"""

def __init__(
self,
size: Shape,
T: Union[float, Tensor, Initializer, Callable] = 36.,
T_base_p: Union[float, Tensor, Initializer, Callable] = 3.55,
T_base_q: Union[float, Tensor, Initializer, Callable] = 3.,
g_max: Union[float, Tensor, Initializer, Callable] = 2.,
V_sh: Union[float, Tensor, Initializer, Callable] = 25.,
method: str = 'exp_auto',
name: str = None
):
super(ICaHT, self).__init__(size, name=name)

# parameters
self.T = init_param(T, self.num, allow_none=False)
self.T_base_p = init_param(T_base_p, self.num, allow_none=False)
self.T_base_q = init_param(T_base_q, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)
self.V_sh = init_param(V_sh, self.num, allow_none=False)
self.phi_p = self.T_base_p ** ((self.T - 24) / 10)
self.phi_q = self.T_base_q ** ((self.T - 24) / 10)

# variables
self.p = bm.Variable(bm.zeros(self.num))
self.q = bm.Variable(bm.zeros(self.num))

# function
self.integral = odeint(JointEq([self.dp, self.dq]), method=method)

def dp(self, p, t, V):
p_inf = 1. / (1. + bm.exp(-(V + 59. - self.V_sh) / 6.2))
p_tau = 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) + bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612
return self.phi_p * (p_inf - p) / p_tau

def dq(self, q, t, V):
q_inf = 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.))
q_tau = bm.where(V >= (-80. + self.V_sh),
bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28.,
bm.exp((V + 467. - self.V_sh) / 66.6))
return self.phi_q * (q_inf - q) / q_tau

def update(self, t, dt, V, C_Ca, E_Ca):
self.p.value, self.q.value = self.integral(self.p, self.q, t, V, dt)

def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * self.p * self.q * (E_Ca - V)

def reset(self, V, C_Ca, E_Ca):
self.p.value = 1. / (1. + bm.exp(-(V + 59. - self.V_sh) / 6.2))
self.q.value = 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.))


class ICaL(CalciumChannel):
r"""The L-type calcium channel model.

The L-type calcium channel model is adopted from (Inoue, et, al., 2008) [1]_.
Its dynamics is given by:

.. math::

I_{CaL} &= g_{max} p^2 q(V-E_{Ca}) \\
{dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\
&p_{\infty} = {1 \over 1+\exp [-(V+10-V_{sh}) / 4.]} \\
&\tau_{p} = 0.4+{0.7 \over \exp [(V+5-V_{sh}) / 15]+\exp [-(V+5-V_{sh}) / 15]} \\
{dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\
&q_{\infty} = {1 \over 1+\exp [(V+25-V_{sh}) / 2]} \\
&\tau_q = 300 + {100 \over \exp [(V+40-V_{sh}) / 9.5]+\exp [-(V+40-V_{sh}) / 9.5]}

where :math:`\phi_p = 3.55^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}`
are temperature-dependent factors (:math:`T` is the temperature in Celsius),
:math:`E_{Ca}` is the reversal potential of Calcium channel.

Parameters
----------
T : float
The temperature.
T_base_p : float
The base temperature factor of :math:`p` channel.
T_base_q : float
The base temperature factor of :math:`q` channel.
g_max : float
The maximum conductance.
V_sh : float
The membrane potential shift.

References
----------

.. [1] Inoue, Tsuyoshi, and Ben W. Strowbridge. "Transient activity induces a long-lasting
increase in the excitability of olfactory bulb interneurons." Journal of
neurophysiology 99, no. 1 (2008): 187-199.
"""

def __init__(
self,
size: Shape,
T: Union[float, Tensor, Initializer, Callable] = 36.,
T_base_p: Union[float, Tensor, Initializer, Callable] = 3.55,
T_base_q: Union[float, Tensor, Initializer, Callable] = 3.,
g_max: Union[float, Tensor, Initializer, Callable] = 1.,
V_sh: Union[float, Tensor, Initializer, Callable] = 0.,
method: str = 'exp_auto',
name: str = None
):
super(ICaL, self).__init__(size, name=name)

# parameters
self.T = init_param(T, self.num, allow_none=False)
self.T_base_p = init_param(T_base_p, self.num, allow_none=False)
self.T_base_q = init_param(T_base_q, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)
self.V_sh = init_param(V_sh, self.num, allow_none=False)
self.phi_p = self.T_base_p ** ((self.T - 24) / 10)
self.phi_q = self.T_base_q ** ((self.T - 24) / 10)

# variables
self.p = bm.Variable(bm.zeros(self.num))
self.q = bm.Variable(bm.zeros(self.num))

# function
self.integral = odeint(JointEq([self.dp, self.dq]), method=method)

def dp(self, p, t, V):
p_inf = 1. / (1 + bm.exp(-(V + 10. - self.V_sh) / 4.))
p_tau = 0.4 + .7 / (bm.exp(-(V + 5. - self.V_sh) / 15.) + bm.exp((V + 5. - self.V_sh) / 15.))
dpdt = self.phi_p * (p_inf - p) / p_tau
return dpdt

def dq(self, q, t, V):
q_inf = 1. / (1. + bm.exp((V + 25. - self.V_sh) / 2.))
q_tau = 300. + 100. / (bm.exp((V + 40 - self.V_sh) / 9.5) + bm.exp(-(V + 40 - self.V_sh) / 9.5))
dqdt = self.phi_q * (q_inf - q) / q_tau
return dqdt

def update(self, t, dt, V, C_Ca, E_Ca):
self.p.value, self.q.value = self.integral(self.p, self.q, t, V, dt)

def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * self.p * self.q * (E_Ca - V)

def reset(self, V, C_Ca, E_Ca):
self.p.value = 1. / (1 + bm.exp(-(V + 10. - self.V_sh) / 4.))
self.q.value = 1. / (1. + bm.exp((V + 25. - self.V_sh) / 2.))

+ 249
- 0
brainpy/dyn/channels/IH.py View File

@@ -0,0 +1,249 @@
# -*- coding: utf-8 -*-

"""
This module implements hyperpolarization-activated cation channels.

"""

from typing import Union, Callable

import brainpy.math as bm
from brainpy.initialize import Initializer, parameter, variable
from brainpy.integrators import odeint, JointEq
from brainpy.types import Shape, Array
from brainpy.modes import Mode, BatchingMode, normal
from .base import IhChannel, CalciumChannel, Calcium

__all__ = [
'Ih_HM1992',
'Ih_De1996',
]


class Ih_HM1992(IhChannel):
r"""The hyperpolarization-activated cation current model propsoed by (Huguenard & McCormick, 1992) [1]_.

The hyperpolarization-activated cation current model is adopted from
(Huguenard, et al., 1992) [1]_. Its dynamics is given by:

.. math::

\begin{aligned}
I_h &= g_{\mathrm{max}} p (V - E) \\
\frac{dp}{dt} &= \phi \frac{p_{\infty} - p}{\tau_p} \\
p_{\infty} &=\frac{1}{1+\exp ((V+75) / 5.5)} \\
\tau_{p} &=\frac{1}{\exp (-0.086 V-14.59)+\exp (0.0701 V-1.87)}
\end{aligned}

where :math:`\phi=1` is a temperature-dependent factor.

Parameters
----------
g_max : float
The maximal conductance density (:math:`mS/cm^2`).
E : float
The reversal potential (mV).
phi : float
The temperature-dependent factor.

References
----------
.. [1] Huguenard, John R., and David A. McCormick. "Simulation of the currents
involved in rhythmic oscillations in thalamic relay neurons." Journal
of neurophysiology 68, no. 4 (1992): 1373-1383.
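
Examples
--------
A minimal sketch (assuming the channel can be instantiated standalone) showing
that the gate is mostly open at hyperpolarized potentials and closes on
depolarization:

>>> ih = Ih_HM1992(size=1)
>>> p_hyper = ih.f_p_inf(-90.)   # ≈ 0.94, largely open at -90 mV
>>> p_depol = ih.f_p_inf(-50.)   # ≈ 0.01, nearly closed at -50 mV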

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
g_max: Union[float, Array, Initializer, Callable] = 10.,
E: Union[float, Array, Initializer, Callable] = 43.,
phi: Union[float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(Ih_HM1992, self).__init__(size,
keep_size=keep_size,
name=name,
mode=mode)

# parameters
self.phi = parameter(phi, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)
self.E = parameter(E, self.varshape, allow_none=False)

# variable
self.p = variable(bm.zeros, mode, self.varshape)

# function
self.integral = odeint(self.derivative, method=method)

def derivative(self, p, t, V):
return self.phi * (self.f_p_inf(V) - p) / self.f_p_tau(V)

def reset_state(self, V, batch_size=None):
self.p.value = self.f_p_inf(V)
if batch_size is not None:
assert self.p.shape[0] == batch_size

def update(self, tdi, V):
self.p.value = self.integral(self.p.value, tdi['t'], V, tdi['dt'])

def current(self, V):
return self.g_max * self.p * (self.E - V)

def f_p_inf(self, V):
return 1. / (1. + bm.exp((V + 75.) / 5.5))

def f_p_tau(self, V):
return 1. / (bm.exp(-0.086 * V - 14.59) + bm.exp(0.0701 * V - 1.87))


class Ih_De1996(IhChannel, CalciumChannel):
r"""The hyperpolarization-activated cation current model propsoed by (Destexhe, et al., 1996) [1]_.

The full kinetic schema was

.. math::

\begin{gathered}
C \underset{\beta(V)}{\stackrel{\alpha(V)}{\rightleftarrows}} O \\
P_{0}+2 \mathrm{Ca}^{2+} \underset{k_{2}}{\stackrel{k_{1}}{\rightleftarrows}} P_{1} \\
O+P_{1} \underset{k_{4}}{\stackrel{k_{3}}{\rightleftarrows}} O_{\mathrm{L}}
\end{gathered}

where the first reaction represents the voltage-dependent transitions of :math:`I_h` channels
between closed (C) and open (O) forms, with :math:`\alpha` and :math:`\beta` as transition rates.
The second reaction represents the binding of intracellular :math:`\mathrm{Ca^{2+}}` ions to a
regulating factor (:math:`P_0` for unbound and :math:`P_1` for bound) with four binding sites for
calcium and rates of :math:`k_1 = 2.5 \times 10^{7}\, mM^{-4} \, ms^{-1}` and :math:`k_2 = 4 \times 10^{-4} \, ms^{-1}`
(half-activation of 0.002 mM :math:`Ca^{2+}`). The calcium-bound form :math:`P_1` associates
with the open form of the channel, leading to a locked open form :math:`O_L`, with rates of
:math:`k_3=0.1 \, ms^{-1}` and :math:`k_4 = 0.001 \, ms^{-1}`.

The current is proportional to the relative concentration of open channels

.. math::

I_h = g_h (O+g_{inc}O_L) (V - E_h)

with a maximal conductance of :math:`\bar{g}_{\mathrm{h}}=0.02 \mathrm{mS} / \mathrm{cm}^{2}`
and a reversal potential of :math:`E_{\mathrm{h}}=-40 \mathrm{mV}`. Because of the factor
:math:`g_{\text {inc }}=2`, the conductance of the calcium-bound open state of
:math:`I_{\mathrm{h}}` channels is twice that of the unbound open state. This produces an
augmentation of conductance after the binding of :math:`\mathrm{Ca}^{2+}`, as observed in
sino-atrial cells (Hagiwara and Irisawa 1989).

The rates of :math:`\alpha` and :math:`\beta` are:

.. math::

& \alpha = m_{\infty} / \tau_m \\
& \beta = (1-m_{\infty}) / \tau_m \\
& m_{\infty} = 1/(1+\exp((V+75-V_{sh})/5.5)) \\
& \tau_m = (5.3 + 267/(\exp((V+71.5-V_{sh})/14.2) + \exp(-(V+89-V_{sh})/11.6)))

and the temperature regulating factor :math:`\phi=2^{(T-24)/10}`.

References
----------
.. [1] Destexhe, Alain, et al. "Ionic mechanisms underlying synchronized
oscillations and propagating waves in a model of ferret thalamic
slices." Journal of neurophysiology 76.3 (1996): 2049-2070.
"""

master_type = Calcium

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -40.,
k2: Union[float, Array, Initializer, Callable] = 4e-4,
k4: Union[float, Array, Initializer, Callable] = 1e-3,
V_sh: Union[float, Array, Initializer, Callable] = 0.,
g_max: Union[float, Array, Initializer, Callable] = 0.02,
g_inc: Union[float, Array, Initializer, Callable] = 2.,
Ca_half: Union[float, Array, Initializer, Callable] = 2e-3,
T: Union[float, Array] = 36.,
T_base: Union[float, Array] = 3.,
phi: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
# IhChannel.__init__(self, size, name=name, keep_size=keep_size)
CalciumChannel.__init__(self,
size,
keep_size=keep_size,
name=name,
mode=mode)

# parameters
self.T = parameter(T, self.varshape, allow_none=False)
self.T_base = parameter(T_base, self.varshape, allow_none=False)
if phi is None:
self.phi = self.T_base ** ((self.T - 24.) / 10)
else:
self.phi = parameter(phi, self.varshape, allow_none=False)
self.E = parameter(E, self.varshape, allow_none=False)
self.k2 = parameter(k2, self.varshape, allow_none=False)
self.Ca_half = parameter(Ca_half, self.varshape, allow_none=False)
self.k1 = self.k2 / self.Ca_half ** 4
self.k4 = parameter(k4, self.varshape, allow_none=False)
self.k3 = self.k4 / 0.01
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)
self.g_inc = parameter(g_inc, self.varshape, allow_none=False)

# variable
self.O = variable(bm.zeros, mode, self.varshape)
self.OL = variable(bm.zeros, mode, self.varshape)
self.P1 = variable(bm.zeros, mode, self.varshape)

# function
self.integral = odeint(JointEq(self.dO, self.dOL, self.dP1), method=method)

def dO(self, O, t, OL, V):
inf = self.f_inf(V)
tau = self.f_tau(V)
alpha = inf / tau
beta = (1 - inf) / tau
return alpha * (1 - O - OL) - beta * O

def dOL(self, OL, t, O, P1):
return self.k3 * P1 * O - self.k4 * OL

def dP1(self, P1, t, C_Ca):
return self.k1 * C_Ca ** 4 * (1 - P1) - self.k2 * P1

def update(self, tdi, V, C_Ca, E_Ca):
self.O.value = self.integral(self.O.value, self.OL.value, self.P1.value,
tdi['t'], V=V, C_Ca=C_Ca, dt=tdi['dt'])

def current(self, V, C_Ca, E_Ca):
return self.g_max * (self.O + self.g_inc * self.OL) * (self.E - V)

def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
varshape = self.varshape if (batch_size is None) else ((batch_size,) + self.varshape)
self.P1.value = bm.broadcast_to(self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2), varshape)
inf = self.f_inf(V)
tau = self.f_tau(V)
alpha = inf / tau
beta = (1 - inf) / tau
self.O.value = alpha / (alpha + alpha * self.k3 * self.P1 / self.k4 + beta)
self.OL.value = self.k3 * self.P1 * self.O / self.k4
if batch_size is not None:
assert self.P1.shape[0] == batch_size
assert self.O.shape[0] == batch_size
assert self.OL.shape[0] == batch_size

def f_inf(self, V):
return 1 / (1 + bm.exp((V + 75 - self.V_sh) / 5.5))

def f_tau(self, V):
return (20. + 1000 / (bm.exp((V + 71.5 - self.V_sh) / 14.2) +
bm.exp(-(V + 89 - self.V_sh) / 11.6))) / self.phi

+ 0
- 95
brainpy/dyn/channels/Ih_channels.py View File

@@ -1,95 +0,0 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable

import brainpy.math as bm
from brainpy.dyn.base import ConNeuGroup
from brainpy.initialize import Initializer, init_param
from brainpy.integrators import odeint
from brainpy.types import Shape, Tensor
from .base import IonChannel

__all__ = [
'IhChannel',
'Ih',
]


class IhChannel(IonChannel):
"""Base class for Ih channel models."""
master_cls = ConNeuGroup


class Ih(IhChannel):
r"""The hyperpolarization-activated cation current model.

The hyperpolarization-activated cation current model is adopted from (Huguenard, et, al., 1992) [1]_.
Its dynamics is given by:

.. math::

\begin{aligned}
I_h &= g_{\mathrm{max}} p
\\
\frac{dp}{dt} &= \phi \frac{p_{\infty} - p}{\tau_p}
\\
p_{\infty} &=\frac{1}{1+\exp ((V+75) / 5.5)}
\\
\tau_{p} &=\frac{1}{\exp (-0.086 V-14.59)+\exp (0.0701 V-1.87)}
\end{aligned}

where :math:`\phi=1` is a temperature-dependent factor.

Parameters
----------
g_max : float
The maximal conductance density (:math:`mS/cm^2`).
E : float
The reversal potential (mV).
phi : float
The temperature-dependent factor.

References
----------
.. [1] Huguenard, John R., and David A. McCormick. "Simulation of the currents
involved in rhythmic oscillations in thalamic relay neurons." Journal
of neurophysiology 68, no. 4 (1992): 1373-1383.

"""

def __init__(
self,
size: Shape,
g_max: Union[float, Tensor, Initializer, Callable]=10.,
E: Union[float, Tensor, Initializer, Callable]=-90.,
phi: Union[float, Tensor, Initializer, Callable]=1.,
method: str = 'exp_auto',
name: str = None
):
super(Ih, self).__init__(size, name=name)

# parameters
self.phi = init_param(phi, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)
self.E = init_param(E, self.num, allow_none=False)

# variable
self.p = bm.Variable(bm.zeros(self.num))

# function
self.integral = odeint(self.derivative, method=method)

def derivative(self, p, t, V):
p_inf = 1. / (1. + bm.exp((V + 75.) / 5.5))
p_tau = 1. / (bm.exp(-0.086 * V - 14.59) + bm.exp(0.0701 * V - 1.87))
return self.phi * (p_inf - p) / p_tau

def reset(self, V):
self.p.value = 1. / (1. + bm.exp((V + 75.) / 5.5))

def update(self, t, dt, V):
self.p.value = self.integral(self.p, t, V, dt=dt)

def current(self, V):
g = self.g_max * self.p
return g * (self.E - V)

+ 1021
- 0
brainpy/dyn/channels/K.py View File

@@ -0,0 +1,1021 @@
# -*- coding: utf-8 -*-

"""
This module implements voltage-dependent potassium channels.

"""

from typing import Union, Callable, Optional

import brainpy.math as bm
from brainpy.initialize import Initializer, parameter, variable
from brainpy.integrators import odeint, JointEq
from brainpy.types import Shape, Array
from brainpy.modes import Mode, BatchingMode, normal
from .base import PotassiumChannel

__all__ = [
'IK_p4_markov',
'IKDR_Ba2002',
'IK_TM1991',
'IK_HH1952',

'IKA_p4q_ss',
'IKA1_HM1992',
'IKA2_HM1992',

'IKK2_pq_ss',
'IKK2A_HM1992',
'IKK2B_HM1992',

'IKNI_Ya1989',
]


class IK_p4_markov(PotassiumChannel):
r"""The delayed rectifier potassium channel of :math:`p^4`
current which described with first-order Markov chain.

This general potassium current model should have the form of

.. math::

\begin{aligned}
I_{\mathrm{K}} &= g_{\mathrm{max}} p^4 (V - E) \\
\frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p)
\end{aligned}

where :math:`\phi` is a temperature-dependent factor.

Parameters
----------
size: int, sequence of int
The object size.
keep_size: bool
Whether to keep the geometry `size` as the variable shape. If ``False``, the
variable shape is flattened to `num`.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
phi : float, JaxArray, ndarray, Initializer, Callable
The temperature-dependent factor.
method: str
The numerical integration method.
name: str
The object name.
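
Examples
--------
``IK_p4_markov`` is a template: a concrete channel only needs to supply the two
rate functions. A hedged sketch of such a subclass (hypothetical rates, for
illustration only; they are not part of this module):

>>> import brainpy.math as bm
>>> class MyIK(IK_p4_markov):
...     def f_p_alpha(self, V):
...         return 0.01 * (V + 55.) / (1. - bm.exp(-(V + 55.) / 10.))
...     def f_p_beta(self, V):
...         return 0.125 * bm.exp(-(V + 65.) / 80.)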

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 10.,
phi: Union[float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(IK_p4_markov, self).__init__(size,
keep_size=keep_size,
name=name,
mode=mode)

self.E = parameter(E, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)
self.phi = parameter(phi, self.varshape, allow_none=False)

# variables
self.p = variable(bm.zeros, mode, self.varshape)

# function
self.integral = odeint(self.derivative, method=method)

def derivative(self, p, t, V):
return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p)

def update(self, tdi, V):
self.p.value = self.integral(self.p, tdi['t'], V, tdi['dt'])

def current(self, V):
return self.g_max * self.p ** 4 * (self.E - V)

def reset_state(self, V, batch_size=None):
alpha = self.f_p_alpha(V)
beta = self.f_p_beta(V)
self.p.value = alpha / (alpha + beta)
if batch_size is not None:
assert self.p.shape[0] == batch_size

def f_p_alpha(self, V):
raise NotImplementedError

def f_p_beta(self, V):
raise NotImplementedError


class IKDR_Ba2002(IK_p4_markov):
r"""The delayed rectifier potassium channel current.

The potassium current model is adopted from (Bazhenov, et al., 2002) [1]_.
Its dynamics is given by:

.. math::

\begin{aligned}
I_{\mathrm{K}} &= g_{\mathrm{max}} p^4 (V - E) \\
\frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\
\alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\
\beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right)
\end{aligned}

where :math:`\phi` is a temperature-dependent factor, which is given by
:math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius).

Parameters
----------
size: int, sequence of int
The object size.
keep_size: bool
Whether to keep the geometry `size` as the variable shape. If ``False``, the
variable shape is flattened to `num`.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
T_base : float, JaxArray, ndarray
The base of temperature factor.
T : float, JaxArray, ndarray, Initializer, Callable
The temperature (Celsius, :math:`^{\circ}C`).
V_sh : float, JaxArray, ndarray, Initializer, Callable
The shift of the membrane potential to spike.
method: str
The numerical integration method.
name: str
The object name.

References
----------
.. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations
and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704.

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 10.,
V_sh: Union[float, Array, Initializer, Callable] = -50.,
T_base: Union[float, Array] = 3.,
T: Union[float, Array] = 36.,
phi: Optional[Union[float, Array, Initializer, Callable]] = None,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
phi = T_base ** ((T - 36) / 10) if phi is None else phi
super(IKDR_Ba2002, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
g_max=g_max,
phi=phi,
E=E,
mode=mode)

# parameters
self.T = parameter(T, self.varshape, allow_none=False)
self.T_base = parameter(T_base, self.varshape, allow_none=False)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_alpha(self, V):
tmp = V - self.V_sh - 15.
return 0.032 * tmp / (1. - bm.exp(-tmp / 5.))

def f_p_beta(self, V):
return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.)
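

# A hedged usage sketch (added for illustration).  With the defaults T = 36 and
# T_base = 3, the temperature factor is phi = 3 ** ((36 - 36) / 10) = 1, and
# ``reset_state`` seeds p at its steady state alpha_p / (alpha_p + beta_p).
def _example_ikdr_ba2002():
  ch = IKDR_Ba2002(10)           # ten compartments with default parameters
  V = bm.zeros(10) - 65.         # a hypothetical membrane potential (mV)
  ch.reset_state(V)              # p <- alpha_p(V) / (alpha_p(V) + beta_p(V))
  return ch.current(V)           # g_max * p**4 * (E - V)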


class IK_TM1991(IK_p4_markov):
r"""The potassium channel described by (Traub and Miles, 1991) [1]_.

The dynamics of this channel is given by:

.. math::

\begin{aligned}
I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 * (E - V) \\
\frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\
\alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\
\beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40)
\end{aligned}

where :math:`V_{sh}` is the membrane shift (default -60 mV), and
:math:`\phi` is the temperature-dependent factor (default 1.).

Parameters
----------
size: int, sequence of int
The geometry size.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
method: str
The numerical integration method.
name: str
The object name.

References
----------
.. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus.
Vol. 777. Cambridge University Press, 1991.

See Also
--------
INa_TM1991
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 10.,
phi: Union[float, Array, Initializer, Callable] = 1.,
V_sh: Union[int, float, Array, Initializer, Callable] = -60.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(IK_TM1991, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
phi=phi,
E=E,
g_max=g_max,
mode=mode)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_alpha(self, V):
c = 15 - V + self.V_sh
return 0.032 * c / (bm.exp(c / 5) - 1.)

def f_p_beta(self, V):
return 0.5 * bm.exp((10 - V + self.V_sh) / 40)


class IK_HH1952(IK_p4_markov):
r"""The potassium channel described by Hodgkin–Huxley model [1]_.

The dynamics of this channel is given by:

.. math::

\begin{aligned}
I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 * (E - V) \\
\frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\
\alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\
\beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right)
\end{aligned}

where :math:`V_{sh}` is the membrane shift (default -45 mV), and
:math:`\phi` is the temperature-dependent factor (default 1.).

Parameters
----------
size: int, sequence of int
The geometry size.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
method: str
The numerical integration method.
name: str
The object name.

References
----------
.. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of
membrane current and its application to conduction and excitation in
nerve." The Journal of physiology 117.4 (1952): 500.

See Also
--------
INa_HH1952
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 10.,
phi: Union[float, Array, Initializer, Callable] = 1.,
V_sh: Union[int, float, Array, Initializer, Callable] = -45.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(IK_HH1952, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
phi=phi,
E=E,
g_max=g_max,
mode=mode)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_alpha(self, V):
temp = V - self.V_sh + 10
return 0.01 * temp / (1 - bm.exp(-temp / 10))

def f_p_beta(self, V):
return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80)


class IKA_p4q_ss(PotassiumChannel):
r"""The rapidly inactivating Potassium channel of :math:`p^4q`
current which described with steady-state format.

This model is developed according to the average behavior of
rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [3]_.

.. math::

&IA = g_{\mathrm{max}} p^4 q (E-V) \\
&\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
&\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\

where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).

Parameters
----------
size: int, sequence of int
The geometry size.
method: str
The numerical integration method.
name: str
The object name.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
phi_p : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
phi_q : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.

References
----------
.. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
currents involved in rhythmic oscillations in thalamic relay
neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
.. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
TEA-sensitive K current in acutely isolated rat thalamic relay
neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 10.,
phi_p: Union[float, Array, Initializer, Callable] = 1.,
phi_q: Union[float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(IKA_p4q_ss, self).__init__(size,
keep_size=keep_size,
name=name,
mode=mode)

# parameters
self.E = parameter(E, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)
self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
self.phi_q = parameter(phi_q, self.varshape, allow_none=False)

# variables
self.p = variable(bm.zeros, mode, self.varshape)
self.q = variable(bm.zeros, mode, self.varshape)

# function
self.integral = odeint(JointEq(self.dp, self.dq), method=method)

def dp(self, p, t, V):
return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)

def dq(self, q, t, V):
return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)

def update(self, tdi, V):
t, dt = tdi['t'], tdi['dt']
self.p.value, self.q.value = self.integral(self.p.value, self.q.value, t, V, dt)

def current(self, V):
return self.g_max * self.p ** 4 * self.q * (self.E - V)

def reset_state(self, V, batch_size=None):
self.p.value = self.f_p_inf(V)
self.q.value = self.f_q_inf(V)
if batch_size is not None:
assert self.p.shape[0] == batch_size
assert self.q.shape[0] == batch_size

def f_p_inf(self, V):
raise NotImplementedError

def f_p_tau(self, V):
raise NotImplementedError

def f_q_inf(self, V):
raise NotImplementedError

def f_q_tau(self, V):
raise NotImplementedError
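

# --- Illustration (added sketch, not part of the public API) -----------------
# Subclasses of ``IKA_p4q_ss`` describe both gates in steady-state form: they
# only provide ``f_p_inf``/``f_p_tau`` and ``f_q_inf``/``f_q_tau``, while the
# JointEq-based integration, the p**4 * q current and the reset are inherited.
# The voltage dependencies below are hypothetical, purely to show the pattern.
class _ExampleIKA(IKA_p4q_ss):
  """A toy p^4 q channel with made-up steady-state curves."""

  def f_p_inf(self, V):
    return 1. / (1. + bm.exp(-(V + 60.) / 8.5))   # hypothetical activation curve

  def f_p_tau(self, V):
    return 1. + 0.1 * bm.abs(V + 60.)             # hypothetical time constant (ms)

  def f_q_inf(self, V):
    return 1. / (1. + bm.exp((V + 78.) / 6.))     # hypothetical inactivation curve

  def f_q_tau(self, V):
    return 20.                                    # hypothetical constant tau (ms)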


class IKA1_HM1992(IKA_p4q_ss):
r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_.

This model is developed according to the average behavior of
rapidly inactivating potassium channels in thalamic relay neurons [2]_ [1]_.

.. math::

&IA = g_{\mathrm{max}} p^4 q (E-V) \\
&\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
&p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\
&\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\
&\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
&q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\
&\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\
\tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array}

where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).

Parameters
----------
size: int, sequence of int
The geometry size.
method: str
The numerical integration method.
name: str
The object name.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
V_sh : float, Array, Callable, Initializer
The membrane potential shift.
phi_p : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
phi_q : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.

References
----------
.. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
currents involved in rhythmic oscillations in thalamic relay
neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
.. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
TEA-sensitive K current in acutely isolated rat thalamic relay
neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.

See Also
--------
IKA2_HM1992
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 30.,
V_sh: Union[float, Array, Initializer, Callable] = 0.,
phi_p: Union[float, Array, Initializer, Callable] = 1.,
phi_q: Union[float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(IKA1_HM1992, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
E=E,
g_max=g_max,
phi_p=phi_p,
phi_q=phi_q,
mode=mode)

# parameters
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_inf(self, V):
return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) / 8.5))

def f_p_tau(self, V):
return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) +
bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37

def f_q_inf(self, V):
return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.))

def f_q_tau(self, V):
return bm.where(V < -63 + self.V_sh,
1. / (bm.exp((V - self.V_sh + 46.) / 5.) +
bm.exp(-(V - self.V_sh + 238.) / 37.5)),
19.)
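

# A hedged single-step sketch (added for illustration).  ``update`` expects a
# dict carrying the current time and the step size, exactly as read out in
# ``IKA_p4q_ss.update`` above.
def _example_ika1_step():
  ch = IKA1_HM1992(8)                 # eight compartments, default parameters
  V = bm.zeros(8) - 70.               # hypothetical membrane potential (mV)
  ch.reset_state(V)                   # p and q start at their steady-state values
  ch.update({'t': 0., 'dt': 0.1}, V)  # advance p and q by one 0.1 ms step
  return ch.current(V)                # g_max * p**4 * q * (E - V)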


class IKA2_HM1992(IKA_p4q_ss):
r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_.

This model is developed according to the average behavior of
rapidly inactivating potassium channels in thalamic relay neurons [2]_ [1]_.

.. math::

&IA = g_{\mathrm{max}} p^4 q (E-V) \\
&\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
&p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\
&\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\
&\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
&q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\
&\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\
\tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array}

where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).

Parameters
----------
size: int, sequence of int
The geometry size.
method: str
The numerical integration method.
name: str
The object name.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
V_sh : float, Array, Callable, Initializer
The membrane potential shift.
phi_p : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
phi_q : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.

References
----------
.. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
currents involved in rhythmic oscillations in thalamic relay
neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
.. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
TEA-sensitive K current in acutely isolated rat thalamic relay
neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.

See Also
--------
IKA1_HM1992
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 20.,
V_sh: Union[float, Array, Initializer, Callable] = 0.,
phi_p: Union[float, Array, Initializer, Callable] = 1.,
phi_q: Union[float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(IKA2_HM1992, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
E=E,
g_max=g_max,
phi_q=phi_q,
phi_p=phi_p,
mode=mode)

# parameters
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_inf(self, V):
return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.))

def f_p_tau(self, V):
return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) +
bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37

def f_q_inf(self, V):
return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.))

def f_q_tau(self, V):
return bm.where(V < -63 + self.V_sh,
1. / (bm.exp((V - self.V_sh + 46.) / 5.) +
bm.exp(-(V - self.V_sh + 238.) / 37.5)),
19.)


class IKK2_pq_ss(PotassiumChannel):
r"""The slowly inactivating Potassium channel of :math:`pq`
current which described with steady-state format.

The dynamics of the model is given as [2]_ [3]_.

.. math::

&IK2 = g_{\mathrm{max}} p q (E-V) \\
&\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
&\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\

where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).

Parameters
----------
size: int, sequence of int
The geometry size.
method: str
The numerical integration method.
name: str
The object name.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
phi_p : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
phi_q : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.

References
----------
.. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
currents involved in rhythmic oscillations in thalamic relay
neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
.. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
TEA-sensitive K current in acutely isolated rat thalamic relay
neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 10.,
phi_p: Union[float, Array, Initializer, Callable] = 1.,
phi_q: Union[float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(IKK2_pq_ss, self).__init__(size,
keep_size=keep_size,
name=name,
mode=mode)

# parameters
self.E = parameter(E, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)
self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
self.phi_q = parameter(phi_q, self.varshape, allow_none=False)

# variables
self.p = variable(bm.zeros, mode, self.varshape)
self.q = variable(bm.zeros, mode, self.varshape)

# function
self.integral = odeint(JointEq(self.dp, self.dq), method=method)

def dp(self, p, t, V):
return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)

def dq(self, q, t, V):
return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)

def update(self, tdi, V):
t, dt = tdi['t'], tdi['dt']
self.p.value, self.q.value = self.integral(self.p.value, self.q.value, t, V, dt)

def current(self, V):
return self.g_max * self.p * self.q * (self.E - V)

def reset_state(self, V, batch_size=None):
self.p.value = self.f_p_inf(V)
self.q.value = self.f_q_inf(V)
if batch_size is not None:
assert self.p.shape[0] == batch_size
assert self.q.shape[0] == batch_size

def f_p_inf(self, V):
raise NotImplementedError

def f_p_tau(self, V):
raise NotImplementedError

def f_q_inf(self, V):
raise NotImplementedError

def f_q_tau(self, V):
raise NotImplementedError


class IKK2A_HM1992(IKK2_pq_ss):
r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_.

The dynamics of the model is given as [2]_ [3]_.

.. math::

&IK2 = g_{\mathrm{max}} p q (E-V) \\
&\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
&p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\
&\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+
\exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\
&\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
&q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\
& \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\

where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).

Parameters
----------
size: int, sequence of int
The geometry size.
method: str
The numerical integration method.
name: str
The object name.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
V_sh : float, Array, Callable, Initializer
The membrane potential shift.
phi_p : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
phi_q : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.

References
----------
.. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
currents involved in rhythmic oscillations in thalamic relay
neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
.. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
TEA-sensitive K current in acutely isolated rat thalamic relay
neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 10.,
V_sh: Union[float, Array, Initializer, Callable] = 0.,
phi_p: Union[float, Array, Initializer, Callable] = 1.,
phi_q: Union[float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(IKK2A_HM1992, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
phi_p=phi_p,
phi_q=phi_q,
g_max=g_max,
E=E,
mode=mode)

# parameters
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_inf(self, V):
return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.))

def f_p_tau(self, V):
return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) +
bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9

def f_q_inf(self, V):
return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6))

def f_q_tau(self, V):
return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
bm.exp(-(V - self.V_sh + 130.) / 7.1))


class IKK2B_HM1992(IKK2_pq_ss):
r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_.

The dynamics of the model is given as [2]_ [3]_.

.. math::

&IK2 = g_{\mathrm{max}} p q (E-V) \\
&\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
&p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\
&\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+
\exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\
&\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
&q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\
&\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) +
\exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\
\tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array}

where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).

Parameters
----------
size: int, sequence of int
The geometry size.
method: str
The numerical integration method.
name: str
The object name.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
V_sh : float, Array, Callable, Initializer
The membrane potential shift.
phi_p : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
phi_q : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`q`.

References
----------
.. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
currents involved in rhythmic oscillations in thalamic relay
neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
.. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
TEA-sensitive K current in acutely isolated rat thalamic relay
neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 10.,
V_sh: Union[float, Array, Initializer, Callable] = 0.,
phi_p: Union[float, Array, Initializer, Callable] = 1.,
phi_q: Union[float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(IKK2B_HM1992, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
phi_p=phi_p,
phi_q=phi_q,
g_max=g_max,
E=E,
mode=mode)

# parameters
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_inf(self, V):
return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.))

def f_p_tau(self, V):
return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) +
bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9

def f_q_inf(self, V):
return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6))

def f_q_tau(self, V):
return bm.where(V < -70 + self.V_sh,
1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
bm.exp(-(V - self.V_sh + 130.) / 7.1)),
8.9)


class IKNI_Ya1989(PotassiumChannel):
r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_.

This slow potassium current can effectively account for spike-frequency adaptation.

.. math::

\begin{aligned}
&I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\
&\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\
&p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\
&\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]}
\end{aligned}

where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and
:math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise.

Parameters
----------
size: int, sequence of int
The geometry size.
method: str
The numerical integration method.
name: str
The object name.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
V_sh : float, Array, Callable, Initializer
The membrane potential shift.
phi_p : optional, float, Array, Callable, Initializer
The temperature factor for channel :math:`p`.
tau_max: float, Array, Callable, Initializer
The :math:`\tau_{\max}` parameter.

References
----------
.. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133.

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -90.,
g_max: Union[float, Array, Initializer, Callable] = 0.004,
phi_p: Union[float, Array, Initializer, Callable] = 1.,
phi_q: Union[float, Array, Initializer, Callable] = 1.,
tau_max: Union[float, Array, Initializer, Callable] = 4e3,
V_sh: Union[float, Array, Initializer, Callable] = 0.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(IKNI_Ya1989, self).__init__(size,
keep_size=keep_size,
name=name,
mode=mode)

# parameters
self.E = parameter(E, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)
self.tau_max = parameter(tau_max, self.varshape, allow_none=False)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
self.phi_q = parameter(phi_q, self.varshape, allow_none=False)

# variables
self.p = variable(bm.zeros, mode, self.varshape)

# function
self.integral = odeint(self.dp, method=method)

def dp(self, p, t, V):
return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)

def update(self, tdi, V):
t, dt = tdi['t'], tdi['dt']
self.p.value = self.integral(self.p.value, t, V, dt)

def current(self, V):
return self.g_max * self.p * (self.E - V)

def reset_state(self, V, batch_size=None):
self.p.value = self.f_p_inf(V)
if batch_size is not None:
assert self.p.shape[0] == batch_size

def f_p_inf(self, V):
return 1. / (1. + bm.exp(-(V - self.V_sh + 35.) / 10.))

def f_p_tau(self, V):
temp = V - self.V_sh + 35.
return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.))
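

# A hedged usage sketch (added for illustration).  The default tau_max = 4e3 ms
# corresponds to the 4 s quoted above; the slow relaxation of p towards
# f_p_inf(V) is what produces spike-frequency adaptation.
def _example_ikni_ya1989():
  ch = IKNI_Ya1989(5)                 # five compartments, g_max = 0.004 mS/cm^2
  V = bm.zeros(5) - 60.               # hypothetical membrane potential (mV)
  ch.reset_state(V)                   # p <- p_inf(V)
  ch.update({'t': 0., 'dt': 0.1}, V)  # one 0.1 ms integration step
  return ch.current(V)                # g_max * p * (E - V)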

+ 127
- 0
brainpy/dyn/channels/KCa.py View File

@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-


"""
This module implements calcium-dependent potassium channels.

"""

from typing import Union, Callable

import brainpy.math as bm
from brainpy.initialize import Initializer, parameter, variable
from brainpy.integrators.ode import odeint
from brainpy.types import Shape, Array
from brainpy.modes import Mode, BatchingMode, normal
from .base import Calcium, CalciumChannel, PotassiumChannel

__all__ = [
'IAHP_De1994',
]


class IAHP_De1994(PotassiumChannel, CalciumChannel):
r"""The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_.

Both in vivo (Contreras et al. 1993; Mulle et al. 1986) and in
vitro recordings (Avanzini et al. 1989) show the presence of a
marked after-hyper-polarization (AHP) after each burst of the RE
cell. This slow AHP is mediated by a slow :math:`Ca^{2+}`-dependent K+
current (Bal and McCormick 1993). (Destexhe, et al., 1994) adopted a
modified version of a model of :math:`I_{KCa}` introduced previously (Yamada et al.
1989) that requires the binding of :math:`nCa^{2+}` to open the channel

.. math::

(\text{closed}) + n \mathrm{Ca}_{i}^{2+} \underset{\beta}{\stackrel{\alpha}{\rightleftharpoons}} (\text{open})

where :math:`Ca_i^{2+}` is the intracellular calcium and :math:`\alpha` and
:math:`\beta` are rate constants. The ionic current is then given by

.. math::

\begin{aligned}
I_{AHP} &= g_{\mathrm{max}} p^2 (V - E_K) \\
{dp \over dt} &= \phi {p_{\infty}(V, [Ca^{2+}]_i) - p \over \tau_p(V, [Ca^{2+}]_i)} \\
p_{\infty} &=\frac{\alpha[Ca^{2+}]_i^n}{\left(\alpha[Ca^{2+}]_i^n + \beta\right)} \\
\tau_p &=\frac{1}{\left(\alpha[Ca^{2+}]_i +\beta\right)}
\end{aligned}

where :math:`E` is the reversal potential, :math:`g_{max}` is the maximum conductance,
:math:`[Ca^{2+}]_i` is the intracellular Calcium concentration.
The values :math:`n=2, \alpha=48 \mathrm{~ms}^{-1} \mathrm{mM}^{-2}` and
:math:`\beta=0.03 \mathrm{~ms}^{-1}` yielded AHPs very similar to those of RE cells
recorded in vivo and in vitro.

Parameters
----------
g_max : float
The maximal conductance density (:math:`mS/cm^2`).
E : float
The reversal potential (mV).

References
----------

.. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated
thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818.

"""

'''The type of the master object.'''
master_type = Calcium

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[float, Array, Initializer, Callable] = -95.,
n: Union[float, Array, Initializer, Callable] = 2,
g_max: Union[float, Array, Initializer, Callable] = 10.,
alpha: Union[float, Array, Initializer, Callable] = 48.,
beta: Union[float, Array, Initializer, Callable] = 0.09,
phi: Union[float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
CalciumChannel.__init__(self,
size=size,
keep_size=keep_size,
name=name,
mode=mode)

# parameters
self.E = parameter(E, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)
self.n = parameter(n, self.varshape, allow_none=False)
self.alpha = parameter(alpha, self.varshape, allow_none=False)
self.beta = parameter(beta, self.varshape, allow_none=False)
self.phi = parameter(phi, self.varshape, allow_none=False)

# variables
self.p = variable(bm.zeros, mode, self.varshape)

# function
self.integral = odeint(self.dp, method=method)

def dp(self, p, t, C_Ca):
C2 = self.alpha * bm.power(C_Ca, self.n)
C3 = C2 + self.beta
return self.phi * (C2 / C3 - p) * C3

def update(self, tdi, V, C_Ca, E_Ca):
t, dt = tdi['t'], tdi['dt']
self.p.value = self.integral(self.p, t, C_Ca=C_Ca, dt=dt)

def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * self.p * (self.E - V)

def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
C2 = self.alpha * bm.power(C_Ca, self.n)
C3 = C2 + self.beta
if batch_size is None:
self.p.value = bm.broadcast_to(C2 / C3, self.varshape)
else:
self.p.value = bm.broadcast_to(C2 / C3, (batch_size,) + self.varshape)
assert self.p.shape[0] == batch_size
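

# An illustrative sketch (added; not part of the original file).  ``dp`` above is
# the usual relaxation form: with C2 = alpha * [Ca]^n and C3 = C2 + beta, we have
# p_inf = C2 / C3 and tau_p = 1 / C3, so dp/dt = phi * (p_inf - p) / tau_p.
# The calcium concentration and reversal potential below are hypothetical values.
def _example_iahp_step():
  ch = IAHP_De1994(4)                  # four compartments, default parameters
  V = bm.zeros(4) - 70.                # hypothetical membrane potential (mV)
  C_Ca = bm.ones(4) * 2.4e-4           # hypothetical intracellular [Ca2+] (mM)
  ch.reset_state(V, C_Ca, E_Ca=120.)   # p <- p_inf([Ca2+])
  ch.update({'t': 0., 'dt': 0.1}, V, C_Ca, E_Ca=120.)
  return ch.current(V, C_Ca, E_Ca=120.)  # g_max * p**2 * (E - V)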

+ 0
- 148
brainpy/dyn/channels/K_channels.py View File

@@ -1,148 +0,0 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable

import brainpy.math as bm
from brainpy.dyn.base import ConNeuGroup
from brainpy.initialize import Initializer, init_param
from brainpy.integrators import odeint
from brainpy.types import Shape, Tensor
from .base import IonChannel

__all__ = [
'PotassiumChannel',
'IK_DR',
'IK2',
]


class PotassiumChannel(IonChannel):
"""Base class for potassium channel."""

'''The type of the master object.'''
master_cls = ConNeuGroup


class IK_DR(PotassiumChannel):
r"""The delayed rectifier potassium channel current.

The potassium current model is adopted from (Bazhenov, et, al. 2002) [1]_.
It's dynamics is given by:

.. math::

\begin{aligned}
I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\
\frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\
\alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\
\beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right)
\end{aligned}

where :math:`\phi` is a temperature-dependent factor, which is given by
:math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius).


Parameters
----------
size: int, sequence of int
The object size.
g_max : float, JaxArray, ndarray, Initializer, Callable
The maximal conductance density (:math:`mS/cm^2`).
E : float, JaxArray, ndarray, Initializer, Callable
The reversal potential (mV).
T : float, JaxArray, ndarray, Initializer, Callable
The temperature (Celsius, :math:`^{\circ}C`).
V_sh : float, JaxArray, ndarray, Initializer, Callable
The shift of the membrane potential to spike.
method: str
The numerical integration method.
name: str
The object name.

References
----------
.. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations
and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704.

"""

def __init__(
self,
size: Shape,
E: Union[float, Tensor, Initializer, Callable] = -90.,
g_max: Union[float, Tensor, Initializer, Callable] = 10.,
T: Union[float, Tensor, Initializer, Callable] = 36.,
T_base: Union[float, Tensor, Initializer, Callable] = 3.,
V_sh: Union[float, Tensor, Initializer, Callable] = -50.,
method: str = 'exp_auto',
name: str = None
):
super(IK_DR, self).__init__(size, name=name)

# parameters
self.T = init_param(T, self.num, allow_none=False)
self.T_base = init_param(T_base, self.num, allow_none=False)
self.E = init_param(E, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)
self.V_sh = init_param(V_sh, self.num, allow_none=False)
self.phi = self.T_base ** ((self.T - 36) / 10)

# variables
self.p = bm.Variable(bm.zeros(self.num))

# function
self.integral = odeint(self.derivative, method=method)

def derivative(self, p, t, V):
alpha = 0.032 * (V - self.V_sh - 15.) / (1. - bm.exp(-(V - self.V_sh - 15.) / 5.))
beta = 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.)
return self.phi * (alpha * (1. - p) - beta * p)

def update(self, t, dt, V):
self.p.value = self.integral(self.p, t, V, dt=dt)

def current(self, V):
return self.g_max * self.p ** 4 * (self.E - V)

def reset(self, V):
alpha = 0.032 * (V - self.V_sh - 15.) / (1. - bm.exp(-(V - self.V_sh - 15.) / 5.))
beta = 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.)
self.p.value = alpha / (alpha + beta)


class IK2(PotassiumChannel):
def __init__(
self,
size: Shape,
E: Union[float, Tensor, Initializer, Callable] = -90.,
g_max: Union[float, Tensor, Initializer, Callable] = 10.,
method='exp_auto',
name=None
):
super(IK2, self).__init__(size, name=name)

# parameters
self.E = init_param(E, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)

# variables
self.n = bm.Variable(bm.zeros(self.num))

# function
self.integral = odeint(self.derivative, method=method)

def derivative(self, n, t, V):
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
beta = 0.125 * bm.exp(-(V + 65) / 80)
return alpha * (1 - n) - beta * n

def update(self, t, dt, V):
self.n.value = self.integral(self.n, t, V, dt)

def current(self, V):
return self.g_max * self.n ** 4 * (self.E - V)

def reset(self, V):
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
beta = 0.125 * bm.exp(-(V + 65) / 80)
self.n.value = alpha / (alpha + beta)

+ 371
- 0
brainpy/dyn/channels/Na.py View File

@@ -0,0 +1,371 @@
# -*- coding: utf-8 -*-

"""
This module implements voltage-dependent sodium channels.

"""

from typing import Union, Callable

import brainpy.math as bm
from brainpy.initialize import Initializer, parameter, variable
from brainpy.integrators import odeint, JointEq
from brainpy.types import Array, Shape
from brainpy.modes import Mode, BatchingMode, normal
from .base import SodiumChannel

__all__ = [
'INa_p3q_markov',
'INa_Ba2002',
'INa_TM1991',
'INa_HH1952',
]


class INa_p3q_markov(SodiumChannel):
r"""The sodium current model of :math:`p^3q` current which described with first-order Markov chain.

The general model can be used to model the dynamics with:

.. math::

\begin{aligned}
I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q * (E - V) \\
\frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\
\frac{dq}{dt} & = \phi ( \alpha_q (1-q) - \beta_q q) \\
\end{aligned}

where :math:`\phi` is a temperature-dependent factor.

Parameters
----------
g_max : float, Array, Callable, Initializer
The maximal conductance density (:math:`mS/cm^2`).
E : float, Array, Callable, Initializer
The reversal potential (mV).
phi : float, Array, Callable, Initializer
The temperature-dependent factor.
method: str
The numerical method
name: str
The name of the object.

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[int, float, Array, Initializer, Callable] = 50.,
g_max: Union[int, float, Array, Initializer, Callable] = 90.,
phi: Union[int, float, Array, Initializer, Callable] = 1.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(INa_p3q_markov, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)

# parameters
self.E = parameter(E, self.varshape, allow_none=False)
self.phi = parameter(phi, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)

# variables
self.p = variable(bm.zeros, mode, self.varshape)
self.q = variable(bm.zeros, mode, self.varshape)

# function
self.integral = odeint(JointEq([self.dp, self.dq]), method=method)

def reset_state(self, V, batch_size=None):
alpha = self.f_p_alpha(V)
beta = self.f_p_beta(V)
self.p.value = alpha / (alpha + beta)
alpha = self.f_q_alpha(V)
beta = self.f_q_beta(V)
self.q.value = alpha / (alpha + beta)
if batch_size is not None:
assert self.p.shape[0] == batch_size
assert self.q.shape[0] == batch_size

def dp(self, p, t, V):
return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p)

def dq(self, q, t, V):
return self.phi * (self.f_q_alpha(V) * (1. - q) - self.f_q_beta(V) * q)

def update(self, tdi, V):
t, dt = tdi['t'], tdi['dt']
p, q = self.integral(self.p, self.q, t, V, dt)
self.p.value, self.q.value = p, q

def current(self, V):
return self.g_max * self.p ** 3 * self.q * (self.E - V)

def f_p_alpha(self, V):
raise NotImplementedError

def f_p_beta(self, V):
raise NotImplementedError

def f_q_alpha(self, V):
raise NotImplementedError

def f_q_beta(self, V):
raise NotImplementedError


class INa_Ba2002(INa_p3q_markov):
r"""The sodium current model.

The sodium current model is adopted from (Bazhenov, et al., 2002) [1]_.
Its dynamics is given by:

.. math::

\begin{aligned}
I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q * (E - V) \\
\frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\
\alpha_{p} &=\frac{0.32\left(V-V_{sh}-13\right)}{1-\exp \left(-\left(V-V_{sh}-13\right) / 4\right)} \\
\beta_{p} &=\frac{-0.28\left(V-V_{sh}-40\right)}{1-\exp \left(\left(V-V_{sh}-40\right) / 5\right)} \\
\frac{dq}{dt} & = \phi ( \alpha_q (1-q) - \beta_q q) \\
\alpha_q &=0.128 \exp \left(-\left(V-V_{sh}-17\right) / 18\right) \\
\beta_q &= \frac{4}{1+\exp \left(-\left(V-V_{sh}-40\right) / 5\right)}
\end{aligned}

where :math:`\phi` is a temperature-dependent factor, which is given by
:math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius).

Parameters
----------
g_max : float, Array, Callable, Initializer
The maximal conductance density (:math:`mS/cm^2`).
E : float, Array, Callable, Initializer
The reversal potential (mV).
T : float, Array
The temperature (Celsius, :math:`^{\circ}C`).
V_sh : float, Array, Callable, Initializer
The shift of the membrane potential to spike.

References
----------

.. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations
and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704.

See Also
--------
INa_TM1991
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
T: Union[int, float, Array] = 36.,
E: Union[int, float, Array, Initializer, Callable] = 50.,
g_max: Union[int, float, Array, Initializer, Callable] = 90.,
V_sh: Union[int, float, Array, Initializer, Callable] = -50.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(INa_Ba2002, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
phi=3 ** ((T - 36) / 10),
g_max=g_max,
E=E,
mode=mode)
self.T = parameter(T, self.varshape, allow_none=False)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_alpha(self, V):
temp = V - self.V_sh - 13.
return 0.32 * temp / (1. - bm.exp(-temp / 4.))

def f_p_beta(self, V):
temp = V - self.V_sh - 40.
return -0.28 * temp / (1. - bm.exp(temp / 5.))

def f_q_alpha(self, V):
return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.)

def f_q_beta(self, V):
return 4. / (1. + bm.exp(-(V - self.V_sh - 40.) / 5.))
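

# A hedged usage sketch (added for illustration).  Both gates are seeded at their
# steady states by ``reset_state``, and the temperature factor passed to the base
# class is phi = 3 ** ((T - 36) / 10), i.e. 1 at the default T = 36.
def _example_ina_ba2002():
  ch = INa_Ba2002(10)            # ten compartments with default parameters
  V = bm.zeros(10) - 65.         # hypothetical membrane potential (mV)
  ch.reset_state(V)              # p and q start at their steady-state values
  return ch.current(V)           # g_max * p**3 * q * (E - V)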


class INa_TM1991(INa_p3q_markov):
r"""The sodium current model described by (Traub and Miles, 1991) [1]_.

The dynamics of this sodium current model is given by:

.. math::

\begin{split}
\begin{aligned}
I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h (E - V) \\
\frac {dm} {dt} &= \phi(\alpha_m (1-m) - \beta_m m) \\
&\alpha_m(V) = 0.32 \frac{(13 - V + V_{sh})}{\exp((13 - V +V_{sh}) / 4) - 1.} \\
&\beta_m(V) = 0.28 \frac{(V - V_{sh} - 40)}{(\exp((V - V_{sh} - 40) / 5) - 1)} \\
\frac {dh} {dt} &= \phi(\alpha_h (1-h) - \beta_h h) \\
&\alpha_h(V) = 0.128 * \exp((17 - V + V_{sh}) / 18) \\
&\beta_h(V) = 4. / (1 + \exp(-(V - V_{sh} - 40) / 5)) \\
\end{aligned}
\end{split}

where :math:`V_{sh}` is the membrane shift (default -63 mV), and
:math:`\phi` is the temperature-dependent factor (default 1.).

Parameters
----------
size: int, tuple of int
The size of the simulation target.
keep_size: bool
Keep size or flatten the size?
method: str
The numerical method
name: str
The name of the object.
g_max : float, Array, Callable, Initializer
The maximal conductance density (:math:`mS/cm^2`).
E : float, Array, Callable, Initializer
The reversal potential (mV).
V_sh: float, Array, Callable, Initializer
The membrane shift.

References
----------
.. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus.
Vol. 777. Cambridge University Press, 1991.

See Also
--------
INa_Ba2002
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[int, float, Array, Initializer, Callable] = 50.,
g_max: Union[int, float, Array, Initializer, Callable] = 120.,
phi: Union[int, float, Array, Initializer, Callable] = 1.,
V_sh: Union[int, float, Array, Initializer, Callable] = -63.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(INa_TM1991, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
E=E,
phi=phi,
g_max=g_max,
mode=mode)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_alpha(self, V):
temp = 13 - V + self.V_sh
return 0.32 * temp / (bm.exp(temp / 4) - 1.)

def f_p_beta(self, V):
temp = V - self.V_sh - 40
return 0.28 * temp / (bm.exp(temp / 5) - 1)

def f_q_alpha(self, V):
return 0.128 * bm.exp((17 - V + self.V_sh) / 18)

def f_q_beta(self, V):
return 4. / (1 + bm.exp(-(V - self.V_sh - 40) / 5))


class INa_HH1952(INa_p3q_markov):
r"""The sodium current model described by Hodgkin–Huxley model [1]_.

The dynamics of this sodium current model is given by:

.. math::

\begin{split}
\begin{aligned}
I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h (E - V) \\
\frac {dm} {dt} &= \phi (\alpha_m (1-m) - \beta_m m) \\
&\alpha_m(V) = \frac {0.1(V-V_{sh}-5)}{1-\exp(\frac{-(V -V_{sh} -5)} {10})} \\
&\beta_m(V) = 4.0 \exp(\frac{-(V -V_{sh}+ 20)} {18}) \\
\frac {dh} {dt} &= \phi (\alpha_h (1-h) - \beta_h h) \\
&\alpha_h(V) = 0.07 \exp(\frac{-(V-V_{sh}+20)}{20}) \\
&\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V -V_{sh}-10)} {10})} \\
\end{aligned}
\end{split}

where :math:`V_{sh}` is the membrane shift (default -45 mV), and
:math:`\phi` is the temperature-dependent factor (default 1.).

Parameters
----------
size: int, tuple of int
The size of the simulation target.
keep_size: bool
Keep size or flatten the size?
method: str
The numerical method
name: str
The name of the object.
g_max : float, Array, Callable, Initializer
The maximal conductance density (:math:`mS/cm^2`).
E : float, Array, Callable, Initializer
The reversal potential (mV).
V_sh: float, Array, Callable, Initializer
The membrane shift.

References
----------
.. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of
membrane current and its application to conduction and excitation in
nerve." The Journal of physiology 117.4 (1952): 500.

See Also
--------
IK_HH1952
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
E: Union[int, float, Array, Initializer, Callable] = 50.,
g_max: Union[int, float, Array, Initializer, Callable] = 120.,
phi: Union[int, float, Array, Initializer, Callable] = 1.,
V_sh: Union[int, float, Array, Initializer, Callable] = -45.,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
super(INa_HH1952, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
E=E,
phi=phi,
g_max=g_max,
mode=mode)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_alpha(self, V):
temp = V - self.V_sh - 5
return 0.1 * temp / (1 - bm.exp(-temp / 10))

def f_p_beta(self, V):
return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18)

def f_q_alpha(self, V):
return 0.07 * bm.exp(-(V - self.V_sh + 20) / 20.)

def f_q_beta(self, V):
return 1 / (1 + bm.exp(-(V - self.V_sh - 10) / 10))
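

# A hedged membrane-equation sketch (illustration only).  A complete HH neuron
# would also include the delayed-rectifier ``IK_HH1952`` from the K module and
# would normally be assembled inside a conductance-based neuron group; here a
# hand-written ohmic leak g_L * (E_L - V) with hypothetical g_L and E_L stands
# in for the rest of the membrane.
def _example_hh_euler_step(dt=0.01):
  na = INa_HH1952(1)                  # a single compartment
  V = bm.zeros(1) - 65.               # hypothetical membrane potential (mV)
  na.reset_state(V)                   # m- and h-like gates start at steady state
  na.update({'t': 0., 'dt': dt}, V)   # advance the gating variables
  I_leak = 0.1 * (-70. - V)           # hypothetical leak current
  dVdt = na.current(V) + I_leak       # channel currents are written as g * (E - V)
  return V + dt * dVdt                # one explicit Euler step, assuming C_m = 1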

+ 0
- 165
brainpy/dyn/channels/Na_channels.py View File

@@ -1,165 +0,0 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable

import brainpy.math as bm
from brainpy.dyn.base import ConNeuGroup
from brainpy.initialize import Initializer, init_param
from brainpy.integrators import odeint, JointEq
from brainpy.types import Tensor, Shape
from .base import IonChannel

__all__ = [
'INa',
'INa_v2',
]


class SodiumChannel(IonChannel):
"""Base class for sodium channel."""
master_cls = ConNeuGroup


class INa(SodiumChannel):
r"""The sodium current model.

The sodium current model is adopted from (Bazhenov, et, al. 2002) [1]_.
It's dynamics is given by:

.. math::

\begin{aligned}
I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\
\frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\
\alpha_{p} &=\frac{0.32\left(V-V_{sh}-13\right)}{1-\exp \left(-\left(V-V_{sh}-13\right) / 4\right)} \\
\beta_{p} &=\frac{-0.28\left(V-V_{sh}-40\right)}{1-\exp \left(\left(V-V_{sh}-40\right) / 5\right)} \\
\frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\
\alpha_q &=0.128 \exp \left(-\left(V-V_{sh}-17\right) / 18\right) \\
\beta_q &= \frac{4}{1+\exp \left(-\left(V-V_{sh}-40\right) / 5\right)}
\end{aligned}

where :math:`\phi` is a temperature-dependent factor, which is given by
:math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius).

**Model Examples**

- `(Brette, et, al., 2007) COBAHH <../../examples/ei_nets/Brette_2007_COBAHH.ipynb>`_

Parameters
----------
g_max : float
The maximal conductance density (:math:`mS/cm^2`).
E : float
The reversal potential (mV).
T : float
The temperature (Celsius, :math:`^{\circ}C`).
V_sh : float
The shift of the membrane potential to spike.

References
----------

.. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations
and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704.

"""

def __init__(
self,
size: Shape,
E: Union[int, float, Tensor, Initializer, Callable] = 50.,
g_max: Union[int, float, Tensor, Initializer, Callable] = 90.,
T: Union[int, float, Tensor, Initializer, Callable] = 36.,
V_sh: Union[int, float, Tensor, Initializer, Callable] = -50.,
method: str = 'exp_auto',
name: str = None
):
super(INa, self).__init__(size, name=name)

# parameters
self.T = init_param(T, self.num, allow_none=False)
self.E = init_param(E, self.num, allow_none=False)
self.V_sh = init_param(V_sh, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)
self.phi = 3 ** ((self.T - 36) / 10)

# variables
self.p = bm.Variable(bm.zeros(self.num))
self.q = bm.Variable(bm.zeros(self.num))

# function
self.integral = odeint(JointEq([self.dp, self.dq]), method=method)

def reset(self, V):
alpha = 0.32 * (V - self.V_sh - 13.) / (1. - bm.exp(-(V - self.V_sh - 13.) / 4.))
beta = -0.28 * (V - self.V_sh - 40.) / (1. - bm.exp((V - self.V_sh - 40.) / 5.))
self.p.value = alpha / (alpha + beta)
alpha = 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.)
beta = 4. / (1. + bm.exp(-(V - self.V_sh - 40.) / 5.))
self.q.value = alpha / (alpha + beta)

def dp(self, p, t, V):
alpha_p = 0.32 * (V - self.V_sh - 13.) / (1. - bm.exp(-(V - self.V_sh - 13.) / 4.))
beta_p = -0.28 * (V - self.V_sh - 40.) / (1. - bm.exp((V - self.V_sh - 40.) / 5.))
return self.phi * (alpha_p * (1. - p) - beta_p * p)

def dq(self, q, t, V):
alpha_q = 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.)
beta_q = 4. / (1. + bm.exp(-(V - self.V_sh - 40.) / 5.))
return self.phi * (alpha_q * (1. - q) - beta_q * q)

def update(self, t, dt, V):
p, q = self.integral(self.p, self.q, t, V, dt)
self.p.value, self.q.value = p, q

def current(self, V):
g = self.g_max * self.p ** 3 * self.q
return g * (self.E - V)


class INa_v2(SodiumChannel):
def __init__(
self,
size: Shape,
E: Union[int, float, Tensor, Initializer, Callable] = 50.,
g_max: Union[int, float, Tensor, Initializer, Callable] = 120.,
method: str = 'exp_auto',
name: str = None
):
super(INa_v2, self).__init__(size, name=name)

# parameters
self.E = init_param(E, self.num, allow_none=False)
self.g_max = init_param(g_max, self.num, allow_none=False)

# variables
self.m = bm.Variable(bm.zeros(self.num))
self.h = bm.Variable(bm.zeros(self.num))

# function
self.integral = odeint(JointEq([self.dm, self.dh]), method=method)

def dm(self, m, t, V):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
beta = 4.0 * bm.exp(-(V + 65) / 18)
return alpha * (1 - m) - beta * m

def dh(self, h, t, V):
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
return alpha * (1 - h) - beta * h

def update(self, t, dt, V):
self.m.value, self.h.value = self.integral(self.m, self.h, t, V, dt)

def current(self, V):
g = self.g_max * self.m ** 3 * self.h
return g * (self.E - V)

def reset(self, V):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
beta = 4.0 * bm.exp(-(V + 65) / 18)
self.m.value = alpha / (alpha + beta)
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
self.h.value = alpha / (alpha + beta)

+ 13
- 14
brainpy/dyn/channels/__init__.py View File

@@ -1,22 +1,21 @@
# -*- coding: utf-8 -*-


from . import K_channels, Na_channels, leaky_channels
from . import base, Ca_channels, Ih_channels
from . import base, Ca, IH, K, Na, KCa, leaky

__all__ = []
__all__ += base.__all__
__all__ += Ca_channels.__all__
__all__ += Ih_channels.__all__
__all__ += Ih_channels.__all__
__all__ += K_channels.__all__
__all__ += leaky_channels.__all__
__all__ += Na_channels.__all__

__all__ += K.__all__
__all__ += Na.__all__
__all__ += Ca.__all__
__all__ += IH.__all__
__all__ += KCa.__all__
__all__ += leaky.__all__

from .base import *
from .Ca_channels import *
from .Ih_channels import *
from .K_channels import *
from .Na_channels import *
from .leaky_channels import *
from .K import *
from .Na import *
from .IH import *
from .Ca import *
from .KCa import *
from .leaky import *

+ 121
- 7
brainpy/dyn/channels/base.py View File

@@ -1,9 +1,20 @@
# -*- coding: utf-8 -*-

from brainpy.dyn.base import Channel, ConNeuGroup
from typing import Union

import brainpy.math as bm
from brainpy.dyn.base import Container, CondNeuGroup, Channel, check_master
from brainpy.types import Shape
from brainpy.modes import normal, Mode

__all__ = [
'Ion', 'IonChannel',

# ions
'Calcium',

# ion channels
'IhChannel', 'CalciumChannel', 'SodiumChannel', 'PotassiumChannel', 'LeakyChannel',
]


@@ -11,12 +22,15 @@ class Ion(Channel):
"""Base class for ions."""

'''The type of the master object.'''
master_cls = ConNeuGroup
master_type = CondNeuGroup

def update(self, t, dt, V):
def update(self, tdi, V):
raise NotImplementedError('Must be implemented by the subclass.')

def reset(self, V):
def reset(self, V, batch_size=None):
self.reset_state(V, batch_size)

def reset_state(self, V, batch_size=None):
raise NotImplementedError('Must be implemented by the subclass.')

def current(self, V):
@@ -30,16 +44,116 @@ class IonChannel(Channel):
"""Base class for ion channels."""

'''The type of the master object.'''
master_cls = ConNeuGroup
master_type = CondNeuGroup

def update(self, t, dt, V):
def update(self, tdi, V):
raise NotImplementedError('Must be implemented by the subclass.')

def current(self, V):
raise NotImplementedError('Must be implemented by the subclass.')

def reset(self, V):
def reset(self, V, batch_size=None):
self.reset_state(V, batch_size)

def reset_state(self, V, batch_size=None):
raise NotImplementedError('Must be implemented by the subclass.')

def __repr__(self):
return f'{self.__class__.__name__}(size={self.size})'


class Calcium(Ion, Container):
"""The base calcium dynamics.

Parameters
----------
size: int, sequence of int
The size of the simulation target.
method: str
The numerical integration method.
name: str
The name of the object.
**channels
The calcium dependent channels.
"""

'''The type of the master object.'''
master_type = CondNeuGroup

"""Reversal potential."""
E: Union[float, bm.Variable, bm.JaxArray]

"""Calcium concentration."""
C: Union[float, bm.Variable, bm.JaxArray]

def __init__(
self,
size: Shape,
keep_size: bool = False,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
**channels
):
Ion.__init__(self, size, keep_size=keep_size, mode=mode)
Container.__init__(self, name=name, mode=mode, **channels)
self.method = method

def current(self, V, C_Ca=None, E_Ca=None):
C_Ca = self.C if (C_Ca is None) else C_Ca
E_Ca = self.E if (E_Ca is None) else E_Ca
nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(Channel).values())
check_master(type(self), *nodes)

if len(nodes) == 0:
return 0.
else:
current = nodes[0].current(V, C_Ca, E_Ca)
for node in nodes[1:]:
current += node.current(V, C_Ca, E_Ca)
return current

def register_implicit_nodes(self, *channels, **named_channels):
check_master(type(self), *channels, **named_channels)
super(Calcium, self).register_implicit_nodes(*channels, **named_channels)


class CalciumChannel(IonChannel):
"""Base class for Calcium ion channels."""

'''The type of the master object.'''
master_type = Calcium

def update(self, tdi, V, C_Ca, E_Ca):
raise NotImplementedError

def current(self, V, C_Ca, E_Ca):
raise NotImplementedError

def reset(self, V, C_Ca, E_Ca, batch_size=None):
self.reset_state(V, C_Ca, E_Ca, batch_size)

def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
raise NotImplementedError('Must be implemented by the subclass.')


class IhChannel(IonChannel):
"""Base class for Ih channel models."""
master_type = CondNeuGroup


class PotassiumChannel(IonChannel):
"""Base class for potassium channel."""

'''The type of the master object.'''
master_type = CondNeuGroup


class LeakyChannel(IonChannel):
"""Base class for leaky channel."""
master_type = CondNeuGroup


class SodiumChannel(IonChannel):
"""Base class for sodium channel."""
master_type = CondNeuGroup

+ 90
- 0
brainpy/dyn/channels/leaky.py View File

@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-

"""
This module implements leakage channels.

"""

from typing import Union, Callable

from brainpy.initialize import Initializer, parameter
from brainpy.types import Array, Shape
from brainpy.modes import Mode, BatchingMode, normal

from .base import LeakyChannel

__all__ = [
'IL',
'IKL',
]


class IL(LeakyChannel):
"""The leakage channel current.

Parameters
----------
g_max : float
The leakage conductance.
E : float
The reversal potential.
"""

def __init__(
self,
size,
keep_size: bool = False,
g_max: Union[int, float, Array, Initializer, Callable] = 0.1,
E: Union[int, float, Array, Initializer, Callable] = -70.,
method: str = None,
name: str = None,
mode: Mode = normal,
):
super(IL, self).__init__(size,
keep_size=keep_size,
name=name,
mode=mode)

self.E = parameter(E, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)
self.method = method

def reset_state(self, V, batch_size=None):
pass

def update(self, tdi, V):
pass

def current(self, V):
return self.g_max * (self.E - V)
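

# An illustrative note (added sketch): the leak channel is stateless, so
# ``update`` and ``reset_state`` above are no-ops, and the current is the purely
# ohmic term g_max * (E - V).
def _example_leak_current():
  leak = IL(3)               # three compartments; g_max = 0.1, E = -70 mV by default
  return leak.current(-65.)  # 0.1 * (-70 - (-65)) = -0.5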


class IKL(IL):
"""The potassium leak channel current.

Parameters
----------
g_max : float
The potassium leakage conductance which is modulated by both
acetylcholine and norepinephrine.
E : float
The reversal potential.
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
g_max: Union[int, float, Array, Initializer, Callable] = 0.005,
E: Union[int, float, Array, Initializer, Callable] = -90.,
method: str = None,
name: str = None,
mode: Mode = normal,
):
super(IKL, self).__init__(size=size,
keep_size=keep_size,
g_max=g_max,
E=E,
method=method,
name=name,
mode=mode)

+ 0
- 75
brainpy/dyn/channels/leaky_channels.py View File

@@ -1,75 +0,0 @@
# -*- coding: utf-8 -*-

from brainpy.types import Shape

from brainpy.dyn.base import ConNeuGroup
from .base import IonChannel

__all__ = [
'LeakyChannel',
'IL',
'IKL',
]


class LeakyChannel(IonChannel):
"""Base class for leaky channel."""
master_cls = ConNeuGroup


class IL(LeakyChannel):
"""The leakage channel current.

Parameters
----------
g_max : float
The leakage conductance.
E : float
The reversal potential.
"""

def __init__(
self,
size,
g_max=0.1,
E=-70.,
method: str = None,
name: str = None,
):
super(IL, self).__init__(size, name=name)

self.E = E
self.g_max = g_max
self.method = method

def reset(self, V):
pass

def update(self, t, dt, V):
pass

def current(self, V):
return self.g_max * (self.E - V)


class IKL(IL):
"""The potassium leak channel current.

Parameters
----------
g_max : float
The potassium leakage conductance which is modulated by both
acetylcholine and norepinephrine.
E : float
The reversal potential.
"""

def __init__(
self,
size: Shape,
g_max=0.005,
E=-90.,
method=None,
name=None,
):
super(IKL, self).__init__(size=size, g_max=g_max, E=E, method=method, name=name)

+ 10
- 0
brainpy/dyn/layers/__init__.py View File

@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-

from .dropout import *
from .linear import *
from .nvar import *
from .reservoir import *
from .rnncells import *
from .conv import *
from .normalization import *
from .pooling import *

+ 220
- 0
brainpy/dyn/layers/conv.py View File

@@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-


import jax.lax

import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.initialize import XavierNormal, ZeroInit, parameter
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check

__all__ = [
'GeneralConv',
'Conv1D',
'Conv2D',
'Conv3D'
]


def _check_tuple(v):
if isinstance(v, (tuple, list)):
return tuple(v)
elif isinstance(v, int):
return (v, v)
else:
raise ValueError


def _conv_dimension_numbers(input_shape):
"""Computes the dimension numbers based on the input shape."""
ndim = len(input_shape)
lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
out_spec = lhs_spec
return jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)


class GeneralConv(DynamicalSystem):
"""Applies a convolution to the inputs.

Parameters
----------
in_channels: integer
number of input channels.
out_channels: integer
number of output channels.
kernel_size: sequence[int]
shape of the convolutional kernel. For 1D convolution,
the kernel size can be passed as an integer. For all other cases, it must
be a sequence of integers.
strides: sequence[int]
an integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
padding: str, sequence[int]
either the string `'SAME'`, the string `'VALID'`, the string
`'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
high)` integer pairs that give the padding to apply before and after each
spatial dimension. A single int is interpreted as applying the same padding
in all dims, and passing a single int in a sequence causes the same padding
to be used on both sides.
input_dilation: integer, sequence[int]
an integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `inputs`
(default: 1). Convolution with input dilation `d` is equivalent to
transposed convolution with stride `d`.
kernel_dilation: integer, sequence[int]
an integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of the convolution
kernel (default: 1). Convolution with kernel dilation
is also known as 'atrous convolution'.
groups: integer, default 1.
If specified divides the input
features into groups.
w_init: brainpy.init.Initializer
initializer for the convolutional kernel.
b_init: brainpy.init.Initializer
initializer for the bias.
"""

def __init__(
self,
in_channels,
out_channels,
kernel_size,
strides=None,
padding='SAME',
input_dilation=None,
kernel_dilation=None,
groups=1,
w_init=XavierNormal(),
b_init=ZeroInit(),
mode: Mode = training,
name: str = None,
):
super(GeneralConv, self).__init__(name=name, mode=mode)

self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.strides = strides
self.padding = padding
self.input_dilation = input_dilation
self.kernel_dilation = kernel_dilation
self.groups = groups
self.w_init = w_init
self.b_init = b_init
self.dimension_numbers = None

if isinstance(padding, str):
assert padding in ['SAME', 'VALID']
elif isinstance(padding, tuple):
for k in padding:
assert isinstance(k, int)
else:
raise ValueError

assert out_channels % self.groups == 0, '"out_channels" should be divisible by groups'

assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'
kernel_shape = _check_tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
self.w = parameter(self.w_init, kernel_shape)
self.b = parameter(self.b_init, (1,) * len(self.kernel_size) + (self.out_channels,))
if isinstance(self.mode, TrainingMode):
self.w = bm.TrainVar(self.w)
self.b = bm.TrainVar(self.b)

def _check_input_dim(self, x):
pass

def update(self, sha, x):
self._check_input_dim(x)
if self.strides is None:
self.strides = (1,) * (len(x.shape) - 2)
y = jax.lax.conv_general_dilated(lhs=x.value if isinstance(x, bm.JaxArray) else x,
rhs=self.w.value,
window_strides=self.strides,
padding=self.padding,
lhs_dilation=self.input_dilation,
rhs_dilation=self.kernel_dilation,
feature_group_count=self.groups,
dimension_numbers=self.dimension_numbers)
if self.b is None:
return y
return y + self.b.value

def reset_state(self, batch_size=None):
pass


class Conv1D(GeneralConv):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
**kwargs
):
super(Conv1D, self).__init__(in_channels, out_channels, kernel_size, **kwargs)

self.dimension_numbers = ('NWC', 'WIO', 'NWC')

def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 3:
raise ValueError(
"expected 3D input (got {}D input)".format(ndim)
)
if self.in_channels != x.shape[-1]:
raise ValueError(
f"input channels={x.shape[-1]} needs to have the same size as in_channels={self.in_channels}."
)
assert len(self.kernel_size) == 1, "expected 1D kernel size (got {}D input)".format(self.kernel_size)


class Conv2D(GeneralConv):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
**kwargs
):
super(Conv2D, self).__init__(in_channels, out_channels, kernel_size, **kwargs)

self.dimension_numbers = ('NHWC', 'HWIO', 'NHWC')

def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 4:
raise ValueError(
"expected 4D input (got {}D input)".format(ndim)
)
if self.in_channels != x.shape[-1]:
raise ValueError(
f"input channels={x.shape[-1]} needs to have the same size as in_channels={self.in_channels}."
)
assert len(self.kernel_size) == 2, "expected 2D kernel size (got {}D input)".format(self.kernel_size)


class Conv3D(GeneralConv):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
**kwargs
):
super(Conv3D, self).__init__(in_channels, out_channels, kernel_size, **kwargs)

self.dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')

def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 5:
raise ValueError(
"expected 5D input (got {}D input)".format(ndim)
)
if self.in_channels != x.shape[-1]:
raise ValueError(
f"input channels={x.shape[-1]} needs to have the same size as in_channels={self.in_channels}."
)
assert len(self.kernel_size) == 3, "expected 3D kernel size (got {}D input)".format(self.kernel_size)

brainpy/nn/nodes/ANN/dropout.py → brainpy/dyn/layers/dropout.py View File

@@ -1,14 +1,15 @@
# -*- coding: utf-8 -*-

import brainpy.math as bm
from brainpy.nn.base import Node
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, training

__all__ = [
'Dropout'
]


class Dropout(Node):
class Dropout(DynamicalSystem):
"""A layer that stochastically ignores a subset of inputs each training step.

In training, to compensate for the fraction of input values dropped (`rate`),
@@ -36,17 +37,24 @@ class Dropout(Node):
neural networks from overfitting." The journal of machine learning
research 15.1 (2014): 1929-1958.
"""
def __init__(self, prob, seed=None, **kwargs):
super(Dropout, self).__init__(**kwargs)

def __init__(
self,
prob: float,
seed: int = None,
mode: Mode = training,
name: str = None
):
super(Dropout, self).__init__(mode=mode, name=name)
self.prob = prob
self.rng = bm.random.RandomState(seed=seed)

def init_ff_conn(self):
self.set_output_shape(self.feedforward_shapes)

def forward(self, ff, **shared_kwargs):
if shared_kwargs.get('train', True):
keep_mask = self.rng.bernoulli(self.prob, ff.shape)
return bm.where(keep_mask, ff / self.prob, 0.)
def update(self, sha, x):
if sha.get('fit', True):
keep_mask = self.rng.bernoulli(self.prob, x.shape)
return bm.where(keep_mask, x / self.prob, 0.)
else:
return ff
return x

def reset_state(self, batch_size=None):
pass

+ 190
- 0
brainpy/dyn/layers/linear.py View File

@@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-


from typing import Optional, Callable, Union, Dict

import jax.numpy as jnp

from brainpy import math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.errors import MathError
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.modes import Mode, TrainingMode, training
from brainpy.tools.checking import check_initializer
from brainpy.types import Array

__all__ = [
'Dense',
]


class Dense(DynamicalSystem):
r"""A linear transformation applied over the last dimension of the input.

Mathematically, this node can be defined as:

.. math::

y = x \cdot W + b

Parameters
----------
num_in: int
The number of the input feature. A positive integer.
num_out: int
The number of the output features. A positive integer.
W_initializer: optional, Initializer
The weight initialization.
b_initializer: optional, Initializer
The bias initialization.
mode: Mode
The computing mode, which controls whether the node is trainable (default: ``training``).
"""

def __init__(
self,
num_in: int,
num_out: int,
W_initializer: Union[Initializer, Callable, Array] = XavierNormal(),
b_initializer: Optional[Union[Initializer, Callable, Array]] = ZeroInit(),
mode: Mode = training,
name: str = None,
):
super(Dense, self).__init__(mode=mode, name=name)

# shape
self.num_in = num_in
self.num_out = num_out
if num_in < 0:
raise ValueError(f'Received an invalid value for `num_in`, expected '
f'a positive integer. Received: num_in={num_in}')
if num_out < 0:
raise ValueError(f'Received an invalid value for `num_out`, expected '
f'a positive integer. Received: num_out={num_out}')

# weight initializer
self.weight_initializer = W_initializer
self.bias_initializer = b_initializer
check_initializer(W_initializer, 'weight_initializer')
check_initializer(b_initializer, 'bias_initializer', allow_none=True)

# parameter initialization
self.W = parameter(self.weight_initializer, (num_in, self.num_out))
self.b = parameter(self.bias_initializer, (self.num_out,))
if isinstance(self.mode, TrainingMode):
self.W = bm.TrainVar(self.W)
self.b = None if (self.b is None) else bm.TrainVar(self.b)

def __repr__(self):
return (f'{self.__class__.__name__}(name={self.name}, '
f'num_in={self.num_in}, '
f'num_out={self.num_out}, '
f'mode={self.mode})')

def reset_state(self, batch_size=None):
pass

def update(self, sha, x):
res = x @ self.W
if self.b is not None:
res += self.b

# online fitting data
if sha.get('fit', False) and self.online_fit_by is not None:
self.fit_record['input'] = x
self.fit_record['output'] = res

# offline fitting data
if sha.get('fit', False) and self.offline_fit_by is not None:
self.fit_record['input'] = x
self.fit_record['output'] = res
return res

def online_init(self):
if self.b is None:
num_input = self.num_in
else:
num_input = self.num_in + 1
self.online_fit_by.initialize(feature_in=num_input, feature_out=self.num_out, identifier=self.name)

def online_fit(self,
target: Array,
fit_record: Dict[str, Array]):
if not isinstance(target, (bm.ndarray, jnp.ndarray)):
raise MathError(f'"target" must be a tensor, but got {type(target)}')
x = fit_record['input']
y = fit_record['output']
if x.ndim != 2:
raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, '
f'num_feature), but we got {x.shape}')
if target.ndim != 2:
raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, '
f'num_feature), but we got {target.shape}')
if x.shape[0] != target.shape[0]:
raise ValueError(f'Batch size of the input and target data should be '
f'the same, while we got {x.shape[0]} != {target.shape[0]}.')
if target.shape[1] != y.shape[1]:
raise MathError(f'The output dimension of output and target data should be '
f'the same, while we got {target.shape[1]} != {y.shape[1]}')

# data
if self.b is not None:
x = bm.concatenate([bm.ones((x.shape[0], 1)), x], axis=-1)

# fitting
dW = self.online_fit_by.call(target=target, input=x, output=y, identifier=self.name)

# assign trained weights
if self.b is None:
self.W += dW
else:
db, dW = bm.split(dW, [1])
self.b += db[0]
self.W += dW

def offline_init(self):
if self.b is None:
num_input = self.num_in
else:
num_input = self.num_in + 1
self.offline_fit_by.initialize(feature_in=num_input, feature_out=self.num_out, identifier=self.name)

def offline_fit(self,
target: Array,
fit_record: Dict[str, Array]):
"""The offline training interface for the Dense node."""
# data checking
if not isinstance(target, (bm.ndarray, jnp.ndarray)):
raise MathError(f'"targets" must be a tensor, but got {type(target)}')
xs = fit_record['input']
ys = fit_record['output']
if xs.ndim != 3:
raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, '
f'num_feature), but we got {xs.shape}')
if target.ndim != 3:
raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, '
f'num_feature), but we got {target.shape}')
if ys.shape != target.shape:
raise ValueError(f'The shapes of output and target data should be '
f'the same, while we got {ys.shape} != {target.shape}.')
if xs.shape[0] != target.shape[0]:
raise ValueError(f'Batch size of the input and target data should be '
f'the same, while we got {xs.shape[0]} != {target.shape[0]}.')
if xs.shape[1] != target.shape[1]:
raise MathError(f'The time dimension of input and target data should be '
f'the same, while we got {xs.shape[1]} != {target.shape[1]}')

# get input and target training data
if self.b is not None:
xs = bm.concatenate([bm.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input)

# solve weights by offline training methods
weights = self.offline_fit_by(self.name, target, xs, ys)

# assign trained weights
if self.b is None:
self.W.value = weights
else:
bias, Wff = bm.split(weights, [1])
self.W.value = Wff
self.b.value = bias[0]

brainpy/nn/nodes/ANN/normalization.py → brainpy/dyn/layers/normalization.py View File

@@ -4,11 +4,12 @@ from typing import Union

import jax.nn
import jax.numpy as jnp
import jax.lax

import brainpy.math as bm
from brainpy.initialize import ZeroInit, OneInit, Initializer
from brainpy.nn.base import Node
from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check

__all__ = [
'BatchNorm',
@@ -21,7 +22,7 @@ __all__ = [
]


class BatchNorm(Node):
class BatchNorm(DynamicalSystem):
"""Batch Normalization node.
This layer aims to reduce the internal covariant shift of data. It
normalizes a batch of data by fixing the mean and variance of inputs
@@ -55,8 +56,10 @@ class BatchNorm(Node):
use_scale: bool = True,
beta_init: Initializer = ZeroInit(),
gamma_init: Initializer = OneInit(),
mode: Mode = training,
name: str = None,
**kwargs):
super(BatchNorm, self).__init__(**kwargs)
super(BatchNorm, self).__init__(name=name, mode=mode)
self.epsilon = epsilon
self.bias = use_bias
self.scale = use_scale
@@ -64,25 +67,32 @@ class BatchNorm(Node):
self.gamma_init = gamma_init if use_scale else ()
self.axis = (axis,) if jnp.isscalar(axis) else axis

def _check_input_dim(self):
def _check_input_dim(self, x):
pass

def init_ff_conn(self):
self._check_input_dim()

input_shape = tuple(d for i, d in enumerate(self.feedforward_shapes) if i not in self.axis)
self.beta = bm.TrainVar(self.beta_init(input_shape)) if self.bias else None
self.gamma = bm.TrainVar(self.gamma_init(input_shape)) if self.scale else None
self.set_output_shape(self.feedforward_shapes)

def forward(self, ff, **shared_kwargs):
ed = tuple(None if i in self.axis else slice(None) for i in range(jnp.ndim(ff)))
output = bm.normalize(ff, self.axis, epsilon=self.epsilon)
def update(self, sha, x):
self._check_input_dim(x)

input_shape = tuple(d for i, d in enumerate(x.shape) if i not in self.axis)
self.beta = parameter(self.beta_init, input_shape) if self.bias else None
self.gamma = parameter(self.gamma_init, input_shape) if self.scale else None
if isinstance(self.mode, TrainingMode):
self.beta = bm.TrainVar(self.beta)
self.gamma = bm.TrainVar(self.gamma)

ed = tuple(None if i in self.axis else slice(None) for i in range(jnp.ndim(x)))
output = jax.nn.standardize(x.value, self.axis, epsilon=self.epsilon)
if self.bias and self.scale: return self.gamma[ed] * output + self.beta[ed]
if self.bias: return output + self.beta[ed]
if self.scale: return self.gamma[ed] * output
return output

def reset_state(self, batch_size=None):
pass


class BatchNorm1d(BatchNorm):
"""1-D batch normalization.
@@ -108,8 +118,8 @@ class BatchNorm1d(BatchNorm):
def __init__(self, axis=(0, 1), **kwargs):
super(BatchNorm1d, self).__init__(axis=axis, **kwargs)

def _check_input_dim(self):
ndim = len(self.feedforward_shapes)
def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 2 and ndim != 3:
raise ValueError(
"expected 2D or 3D input (got {}D input)".format(ndim)
@@ -142,8 +152,8 @@ class BatchNorm2d(BatchNorm):
def __init__(self, axis=(0, 1, 2), **kwargs):
super(BatchNorm2d, self).__init__(axis=axis, **kwargs)

def _check_input_dim(self):
ndim = len(self.feedforward_shapes)
def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 4:
raise ValueError(
"expected 4D input (got {}D input)".format(ndim)
@@ -174,15 +184,15 @@ class BatchNorm3d(BatchNorm):
def __init__(self, axis=(0, 1, 2, 3), **kwargs):
super(BatchNorm3d, self).__init__(axis=axis, **kwargs)

def _check_input_dim(self):
ndim = len(self.feedforward_shapes)
def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 5:
raise ValueError(
"expected 5D input (got {}D input)".format(ndim)
)


class LayerNorm(Node):
class LayerNorm(DynamicalSystem):
"""Layer normalization (https://arxiv.org/abs/1607.06450).

This layer normalizes data on each example, independently of the batch. More
@@ -215,8 +225,10 @@ class LayerNorm(Node):
beta_init: Initializer = ZeroInit(),
gamma_init: Initializer = OneInit(),
axis: Union[int, tuple] = None,
mode: Mode = training,
name: str = None,
**kwargs):
super(LayerNorm, self).__init__(**kwargs)
super(LayerNorm, self).__init__(name=name, mode=mode)
self.epsilon = epsilon
self.bias = use_bias
self.scale = use_scale
@@ -224,29 +236,33 @@ class LayerNorm(Node):
self.gamma_init = gamma_init if use_scale else ()
self.axis = (axis,) if jnp.isscalar(axis) else axis

def default_axis(self):
def default_axis(self, x):
# default: the first axis (batch dim) is excluded
return tuple(i for i in range(1, len(self.feedforward_shapes)))
return tuple(i for i in range(1, len(x.shape)))

def init_ff_conn(self):
def update(self, sha, x):
if self.axis is None:
self.axis = self.default_axis()
self.axis = self.default_axis(x)
# todo: what if elementwise_affine = False?
input_shape = tuple(d for i, d in enumerate(self.feedforward_shapes) if i in self.axis)
self.beta = bm.TrainVar(self.beta_init(input_shape)) if self.bias else None
self.gamma = bm.TrainVar(self.gamma_init(input_shape)) if self.scale else None
self.set_output_shape(self.feedforward_shapes)

def forward(self, ff, **shared_kwargs):
ed = tuple(None if i not in self.axis else slice(None) for i in range(jnp.ndim(ff)))
output = bm.normalize(ff, self.axis, epsilon=self.epsilon)
input_shape = tuple(d for i, d in enumerate(x.shape) if i in self.axis)
self.beta = parameter(self.beta_init, input_shape) if self.bias else None
self.gamma = parameter(self.gamma_init, input_shape) if self.scale else None
if isinstance(self.mode, TrainingMode):
self.beta = bm.TrainVar(self.beta)
self.gamma = bm.TrainVar(self.gamma)

ed = tuple(None if i not in self.axis else slice(None) for i in range(jnp.ndim(x)))
output = bm.normalize(x, self.axis, epsilon=self.epsilon)
if self.bias and self.scale: return self.gamma[ed] * output + self.beta[ed]
if self.bias: return output + self.beta[ed]
if self.scale: return self.gamma[ed] * output
return output

def reset_state(self, batch_size=None):
pass


class GroupNorm(Node):
class GroupNorm(DynamicalSystem):
"""Group normalization layer.

This layer divides channels into groups and normalizes the features within each
@@ -287,8 +303,10 @@ class GroupNorm(Node):
beta_init: Initializer = ZeroInit(),
gamma_init: Initializer = OneInit(),
axis: Union[int, tuple] = None,
mode: Mode = training,
name: str = None,
**kwargs):
super(GroupNorm, self).__init__(**kwargs)
super(GroupNorm, self).__init__(name=name, mode=mode)
self.num_groups = num_groups
self.group_size = group_size
self.epsilon = epsilon
@@ -298,9 +316,9 @@ class GroupNorm(Node):
self.gamma_init = gamma_init if use_scale else ()
self.norm_axis = (axis,) if jnp.isscalar(axis) else axis

def init_ff_conn(self):
num_channels = self.feedforward_shapes[-1]
self.ndim = len(self.feedforward_shapes)
def update(self, sha, x):
num_channels = x.shape[-1]
self.ndim = len(x.shape)

# compute num_groups and group_size
if ((self.num_groups is None and self.group_size is None) or
@@ -327,17 +345,18 @@ class GroupNorm(Node):
# axes for normalization
if self.norm_axis is None:
# default: the first axis (batch dim) and the second-last axis (num_group dim) are excluded
self.norm_axis = tuple(i for i in range(1, len(self.feedforward_shapes) - 1)) + (self.ndim,)
self.norm_axis = tuple(i for i in range(1, len(x.shape) - 1)) + (self.ndim,)

group_shape = self.feedforward_shapes[:-1] + (self.num_groups, self.group_size)
group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
input_shape = tuple(d for i, d in enumerate(group_shape) if i in self.norm_axis)
self.beta = bm.TrainVar(self.beta_init(input_shape)) if self.bias else None
self.gamma = bm.TrainVar(self.gamma_init(input_shape)) if self.scale else None
self.set_output_shape(self.feedforward_shapes)

def forward(self, ff, **shared_kwargs):
group_shape = ff.shape[:-1] + (self.num_groups, self.group_size)
ff_reshape = ff.reshape(group_shape)
self.beta = parameter(self.beta_init, input_shape) if self.bias else None
self.gamma = parameter(self.gamma_init, input_shape) if self.scale else None
if isinstance(self.mode, TrainingMode):
self.beta = bm.TrainVar(self.beta)
self.gamma = bm.TrainVar(self.gamma)

group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
ff_reshape = x.reshape(group_shape)
ed = tuple(None if i not in self.norm_axis else slice(None) for i in range(jnp.ndim(ff_reshape)))
output = bm.normalize(ff_reshape, self.norm_axis, epsilon=self.epsilon)
if self.bias and self.scale:
@@ -346,7 +365,7 @@ class GroupNorm(Node):
output = output + self.beta[ed]
elif self.scale:
output = self.gamma[ed] * output
return output.reshape(ff.shape)
return output.reshape(x.shape)


class InstanceNorm(GroupNorm):
@@ -381,4 +400,4 @@ class InstanceNorm(GroupNorm):
**kwargs):
super(InstanceNorm, self).__init__(group_size=1, epsilon=epsilon, use_bias=use_bias,
use_scale=use_scale, beta_init=beta_init,
gamma_init=gamma_init, axis=axis, **kwargs)
gamma_init=gamma_init, axis=axis, **kwargs)

+ 202
- 0
brainpy/dyn/layers/nvar.py View File

@@ -0,0 +1,202 @@
# -*- coding: utf-8 -*-

from itertools import combinations_with_replacement
from typing import Union, Sequence, List

import jax.numpy as jnp
import numpy as np

import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, NormalMode, BatchingMode, batching, check
from brainpy.tools.checking import (check_integer, check_sequence)

__all__ = [
'NVAR'
]


def _comb(N, k):
r"""The number of combinations of N things taken k at a time.

.. math::

\frac{N!}{(N-k)! k!}

"""
if N > k:
val = 1
for j in range(min(k, N - k)):
val = (val * (N - j)) // (j + 1)
return val
elif N == k:
return 1
else:
return 0


class NVAR(DynamicalSystem):
"""Nonlinear vector auto-regression (NVAR) node.

This class has the following features:

- it supports batch size,
- it supports multiple orders,

Parameters
----------
delay: int
The number of delay steps.
order: int, sequence of int
The nonlinear order(s).
stride: int
The stride used to sample the linear part from the stored delays.
constant: bool
Whether to append a constant term to the output features.

References
----------
.. [1] Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation
reservoir computing. Nat Commun 12, 5564 (2021).
https://doi.org/10.1038/s41467-021-25801-2

"""

def __init__(
self,
num_in,
delay: int,
order: Union[int, Sequence[int]] = None,
stride: int = 1,
constant: bool = False,
mode: Mode = batching,
name: str = None,
):
super(NVAR, self).__init__(mode=mode, name=name)
check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)

# parameters
order = tuple() if order is None else order
if not isinstance(order, (tuple, list)):
order = (order,)
self.order = tuple(order)
check_sequence(order, 'order', allow_none=False)
for o in order:
check_integer(o, 'order', allow_none=False, min_bound=2)
check_integer(delay, 'delay', allow_none=False, min_bound=1)
check_integer(stride, 'stride', allow_none=False, min_bound=1)
assert isinstance(constant, bool), f'Must be an instance of boolean, but got {constant}.'
self.delay = delay
self.stride = stride
self.constant = constant
self.num_delay = 1 + (self.delay - 1) * self.stride
self.num_in = num_in

# delay variables
self.idx = bm.Variable(jnp.asarray([0]))
if isinstance(self.mode, BatchingMode):
batch_size = 1 # first initialize the state with batch size = 1
self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in)), batch_axis=1)
else:
self.store = bm.Variable(jnp.zeros((self.num_delay, self.num_in)))

# linear dimension
self.linear_dim = self.delay * num_in
# For each monomial created in the non-linear part, indices
# of the n components involved, n being the order of the
# monomials. Precompute them to improve efficiency.
self.comb_ids = []
for order in self.order:
assert order >= 2, f'"order" must be an integer >= 2, while we got {order}.'
idx = np.array(list(combinations_with_replacement(np.arange(self.linear_dim), order)))
self.comb_ids.append(jnp.asarray(idx))
# number of non-linear components is (d + n - 1)! / (d - 1)! n!
# i.e. number of all unique monomials of order n made from the
# linear components.
self.nonlinear_dim = sum([len(ids) for ids in self.comb_ids])
# output dimension
self.num_out = int(self.linear_dim + self.nonlinear_dim)
if self.constant:
self.num_out += 1

def reset_state(self, batch_size=None):
"""Reset the node state which depends on batch size."""
self.idx[0] = 0
# To store the last inputs.
# Note, the batch axis is not in the first dimension, so we
# manually handle the state of NVAR, rather return it.
if batch_size is None:
self.store.value = jnp.zeros((self.num_delay, self.num_in))
else:
self.store.value = jnp.zeros((self.num_delay, batch_size, self.num_in))

def update(self, sha, x):
all_parts = []
select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay
# 1. Store the current input
self.store[self.idx[0]] = x

if isinstance(self.mode, BatchingMode):
# 2. Linear part:
# select all previous inputs, including the current, with strides
linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature)
linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1))
# 3. constant
if self.constant:
constant = jnp.ones((linear_parts.shape[0], 1), dtype=x.dtype)
all_parts.append(constant)
all_parts.append(linear_parts)
# 3. Nonlinear part:
# select monomial terms and compute them
for ids in self.comb_ids:
all_parts.append(jnp.prod(linear_parts[:, ids], axis=2))

else:
# 2. Linear part:
# select all previous inputs, including the current, with strides
linear_parts = self.store[select_ids].flatten() # (num_time x num_feature,)
# 3. constant
if self.constant:
constant = jnp.ones((1,), dtype=x.dtype)
all_parts.append(constant)
all_parts.append(linear_parts)
# 3. Nonlinear part:
# select monomial terms and compute them
for ids in self.comb_ids:
all_parts.append(jnp.prod(linear_parts[ids], axis=1))

# 4. Finally
self.idx.value = (self.idx + 1) % self.num_delay
return jnp.concatenate(all_parts, axis=-1)

def get_feature_names(self, for_plot=False) -> List[str]:
"""Get output feature names for transformation.

Parameters
----------
for_plot: bool
Use the feature names for plotting or not? (Default False)
"""
if for_plot:
linear_names = [f'x{i}_t' for i in range(self.num_in)]
else:
linear_names = [f'x{i}(t)' for i in range(self.num_in)]
for di in range(1, self.delay):
linear_names.extend([((f'x{i}_' + r'{t-%d}' % (di * self.stride))
if for_plot else f'x{i}(t-{di * self.stride})')
for i in range(self.num_in)])
nonlinear_names = []
for ids in self.comb_ids:
for id_ in np.asarray(ids):
uniques, counts = np.unique(id_, return_counts=True)
nonlinear_names.append(" ".join(
"%s^%d" % (linear_names[ind], exp) if (exp != 1) else linear_names[ind]
for ind, exp in zip(uniques, counts)
))
if for_plot:
all_names = [f'${n}$' for n in linear_names] + [f'${n}$' for n in nonlinear_names]
else:
all_names = linear_names + nonlinear_names
if self.constant:
all_names = ['1'] + all_names
return all_names

brainpy/nn/nodes/ANN/pooling.py → brainpy/dyn/layers/pooling.py View File

@@ -3,7 +3,8 @@

import jax.lax
import brainpy.math as bm
from brainpy.nn.base import Node
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check

__all__ = [
'Pool',
@@ -13,8 +14,11 @@ __all__ = [
]


class Pool(Node):
def __init__(self, init_v, reduce_fn, window_shape, strides, padding, **kwargs):
class Pool(DynamicalSystem):
def __init__(self, init_v, reduce_fn, window_shape, strides, padding,
mode: Mode = training,
name: str = None,
**kwargs):
"""Pooling functions are implemented using the ReduceWindow XLA op.

Args:
@@ -34,7 +38,7 @@ class Pool(Node):
Returns:
The output of the reduction for each window slice.
"""
super(Pool, self).__init__(**kwargs)
super(Pool, self).__init__(name=name, mode=mode)
self.init_v = init_v
self.reduce_fn = reduce_fn
self.window_shape = window_shape
@@ -55,20 +59,18 @@ class Pool(Node):
padding = ((0, 0),) + padding + ((0, 0),)
self.padding = padding

def init_ff_conn(self):
input_shapes = tuple((0,)) + tuple(d for d in self.feedforward_shapes if d is not None)
def update(self, sha, x):
input_shapes = tuple(d for d in x.shape if d is not None)
assert len(input_shapes) == len(self.dims), f"len({len(input_shapes)}) != len({self.dims})"

padding_vals = jax.lax.padtype_to_pads(input_shapes, self.dims, self.strides, self.padding)
ones = (1,) * len(self.dims)
out_shapes = jax.lax.reduce_window_shape_tuple(
input_shapes, self.dims, self.strides, padding_vals, ones, ones)
# padding_vals = jax.lax.padtype_to_pads(input_shapes, self.dims, self.strides, self.padding)
# ones = (1,) * len(self.dims)
# out_shapes = jax.lax.reduce_window_shape_tuple(
# input_shapes, self.dims, self.strides, padding_vals, ones, ones)
#
# out_shapes = tuple((None,)) + tuple(d for i, d in enumerate(out_shapes) if i != 0)

out_shapes = tuple((None,)) + tuple(d for i, d in enumerate(out_shapes) if i != 0)
self.set_output_shape(out_shapes)

def forward(self, ff, fb=None, **shared_kwargs):
y = jax.lax.reduce_window(ff, self.init_v, self.reduce_fn, self.dims, self.strides, self.padding)
y = jax.lax.reduce_window(x, self.init_v, self.reduce_fn, self.dims, self.strides, self.padding)

return y

@@ -99,8 +101,8 @@ class AvgPool(Pool):
padding=padding
)

def forward(self, ff, fb=None, **shared_kwargs):
y = jax.lax.reduce_window(ff, self.init_v, self.reduce_fn, self.dims, self.strides, self.padding)
def update(self, sha, x):
y = jax.lax.reduce_window(x, self.init_v, self.reduce_fn, self.dims, self.strides, self.padding)
y = y / bm.prod(bm.asarray(self.window_shape))
return y


+ 217
- 0
brainpy/dyn/layers/reservoir.py View File

@@ -0,0 +1,217 @@
# -*- coding: utf-8 -*-

from typing import Optional, Union, Callable, Tuple

import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.initialize import Normal, ZeroInit, Initializer, parameter, variable
from brainpy.modes import Mode, TrainingMode, batching
from brainpy.tools.checking import check_float, check_initializer, check_string
from brainpy.tools.others import to_size
from brainpy.types import Array

__all__ = [
'Reservoir',
]


class Reservoir(DynamicalSystem):
r"""Reservoir node, a pool of leaky-integrator neurons
with random recurrent connections [1]_.

Parameters
----------
input_shape: int, tuple of int
The input shape.
num_out: int
The number of reservoir nodes.
Win_initializer: Initializer
The initialization method for the feedforward connections.
Wrec_initializer: Initializer
The initialization method for the recurrent connections.
b_initializer: optional, Array, Initializer
The initialization method for the bias.
leaky_rate: float
A float between 0 and 1.
activation : str, callable, optional
Reservoir activation function.
- If a str, should be a :py:mod:`brainpy.math.activations` function name.
- If a callable, should be an element-wise operator on tensor.
activation_type : str
- If "internal" (default), then leaky integration happens on states transformed
by the activation function:

.. math::

r[n+1] = (1 - \alpha) \cdot r[n] +
\alpha \cdot f(W_{ff} \cdot u[n] + W_{fb} \cdot b[n] + W_{rec} \cdot r[n])

- If "external", then leaky integration happens on internal states of
each neuron, stored in an ``internal_state`` parameter (:math:`x` in
the equation below).
A neuron internal state is the value of its state before applying
the activation function :math:`f`:

.. math::

x[n+1] &= (1 - \alpha) \cdot x[t] +
\alpha \cdot f(W_{ff} \cdot u[n] + W_{rec} \cdot r[t] + W_{fb} \cdot b[n]) \\
r[n+1] &= f(x[n+1])
in_connectivity : float, optional
Connectivity of input neurons, i.e. ratio of input neurons connected
to reservoir neurons. Must be in [0, 1], by default 0.1
rec_connectivity : float, optional
Connectivity of recurrent weights matrix, i.e. ratio of reservoir
neurons connected to other reservoir neurons, including themselves.
Must be in [0, 1], by default 0.1
comp_type: str
The connectivity type, can be "dense" or "sparse".
spectral_radius : float, optional
Spectral radius of recurrent weight matrix, by default None
noise_rec : float, optional
Gain of noise applied to reservoir internal states, by default 0.0
noise_in : float, optional
Gain of noise applied to feedforward signals, by default 0.0
noise_type : optional, str, callable
Distribution of noise. Must be a random variable generator
distribution (see :py:class:`brainpy.math.random.RandomState`),
by default "normal".
seed: optional, int
The seed for random sampling in this node.

References
----------
.. [1] Lukoševičius, Mantas. "A practical guide to applying echo state networks."
Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 659-686.
"""

def __init__(
self,
input_shape: Union[int, Tuple[int]],
num_out: int,
leaky_rate: float = 0.3,
activation: Union[str, Callable] = 'tanh',
activation_type: str = 'internal',
Win_initializer: Union[Initializer, Callable, Array] = Normal(scale=0.1),
Wrec_initializer: Union[Initializer, Callable, Array] = Normal(scale=0.1),
b_initializer: Optional[Union[Initializer, Callable, Array]] = ZeroInit(),
in_connectivity: float = 0.1,
rec_connectivity: float = 0.1,
comp_type='dense',
spectral_radius: Optional[float] = None,
noise_in: float = 0.,
noise_rec: float = 0.,
noise_type: str = 'normal',
seed: Optional[int] = None,
mode: Mode = batching,
name: str = None
):
super(Reservoir, self).__init__(mode=mode, name=name)

# parameters
input_shape = to_size(input_shape)
if input_shape[0] is None:
input_shape = input_shape[1:]
self.input_shape = input_shape
self.output_shape = input_shape[:-1] + (num_out,)
self.num_unit = num_out
assert num_out > 0, f'Must be a positive integer, but we got {num_out}'
self.leaky_rate = leaky_rate
check_float(leaky_rate, 'leaky_rate', 0., 1.)
self.activation = bm.activations.get(activation)
self.activation_type = activation_type
check_string(activation_type, 'activation_type', ['internal', 'external'])
self.rng = bm.random.RandomState(seed)
check_float(spectral_radius, 'spectral_radius', allow_none=True)
self.spectral_radius = spectral_radius

# initializations
check_initializer(Win_initializer, 'ff_initializer', allow_none=False)
check_initializer(Wrec_initializer, 'rec_initializer', allow_none=False)
check_initializer(b_initializer, 'bias_initializer', allow_none=True)
self._Win_initializer = Win_initializer
self._Wrec_initializer = Wrec_initializer
self._b_initializer = b_initializer

# connectivity
check_float(in_connectivity, 'ff_connectivity', 0., 1.)
check_float(rec_connectivity, 'rec_connectivity', 0., 1.)
self.ff_connectivity = in_connectivity
self.rec_connectivity = rec_connectivity
check_string(comp_type, 'conn_type', ['dense', 'sparse'])
self.comp_type = comp_type

# noises
check_float(noise_in, 'noise_ff')
check_float(noise_rec, 'noise_rec')
self.noise_ff = noise_in
self.noise_rec = noise_rec
self.noise_type = noise_type
check_string(noise_type, 'noise_type', ['normal', 'uniform'])

# initialize feedforward weights
weight_shape = (input_shape[-1], self.num_unit)
self.Wff_shape = weight_shape
self.Win = parameter(self._Win_initializer, weight_shape)
if self.ff_connectivity < 1.:
conn_mat = self.rng.random(weight_shape) > self.ff_connectivity
self.Win[conn_mat] = 0.
if self.comp_type == 'sparse' and self.ff_connectivity < 1.:
self.ff_pres, self.ff_posts = bm.where(bm.logical_not(conn_mat))
self.Win = self.Win[self.ff_pres, self.ff_posts]
if isinstance(self.mode, TrainingMode):
self.Win = bm.TrainVar(self.Win)

# initialize recurrent weights
recurrent_shape = (self.num_unit, self.num_unit)
self.Wrec = parameter(self._Wrec_initializer, recurrent_shape)
if self.rec_connectivity < 1.:
conn_mat = self.rng.random(recurrent_shape) > self.rec_connectivity
self.Wrec[conn_mat] = 0.
if self.spectral_radius is not None:
current_sr = max(abs(bm.linalg.eig(self.Wrec)[0]))
self.Wrec *= self.spectral_radius / current_sr
if self.comp_type == 'sparse' and self.rec_connectivity < 1.:
self.rec_pres, self.rec_posts = bm.where(bm.logical_not(conn_mat))
self.Wrec = self.Wrec[self.rec_pres, self.rec_posts]
self.bias = parameter(self._b_initializer, (self.num_unit,))
if isinstance(self.mode, TrainingMode):
self.Wrec = bm.TrainVar(self.Wrec)
self.bias = None if (self.bias is None) else bm.TrainVar(self.bias)

# initialize state
self.state = variable(bm.zeros, mode, self.output_shape)

def reset_state(self, batch_size=None):
self.state.value = variable(bm.zeros, batch_size, self.output_shape)

def update(self, sha, x):
"""Feedforward output."""
# inputs
x = bm.concatenate(x, axis=-1)
if self.noise_ff > 0: x += self.noise_ff * self.rng.uniform(-1, 1, x.shape)
if self.comp_type == 'sparse' and self.ff_connectivity < 1.:
sparse = {'data': self.Win,
'index': (self.ff_pres, self.ff_posts),
'shape': self.Wff_shape}
hidden = bm.sparse_matmul(x, sparse)
else:
hidden = bm.dot(x, self.Win)
# recurrent
if self.comp_type == 'sparse' and self.rec_connectivity < 1.:
sparse = {'data': self.Wrec,
'index': (self.rec_pres, self.rec_posts),
'shape': (self.num_unit, self.num_unit)}
hidden += bm.sparse_matmul(self.state, sparse)
else:
hidden += bm.dot(self.state, self.Wrec)
if self.activation_type == 'internal':
hidden = self.activation(hidden)
if self.noise_rec > 0.:
hidden += self.noise_rec * self.rng.uniform(-1, 1, self.state.shape)
# new state/output
state = (1 - self.leaky_rate) * self.state + self.leaky_rate * hidden
if self.activation_type == 'external':
state = self.activation(state)
self.state.value = state
return state

+ 425
- 0
brainpy/dyn/layers/rnncells.py View File

@@ -0,0 +1,425 @@
# -*- coding: utf-8 -*-


from typing import Union, Callable

import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.initialize import (XavierNormal,
ZeroInit,
Orthogonal,
parameter,
variable,
Initializer)
from brainpy.modes import Mode, TrainingMode, training
from brainpy.tools.checking import (check_integer,
check_initializer)
from brainpy.types import Array

__all__ = [
'VanillaRNN',
'GRU',
'LSTM',
]


class RecurrentCell(DynamicalSystem):
def __init__(self,
num_out: int,
state_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
mode: Mode = training,
train_state: bool = False,
name: str = None):
super(RecurrentCell, self).__init__(mode=mode, name=name)

# parameters
self._state_initializer = state_initializer
check_initializer(state_initializer, 'state_initializer', allow_none=False)
self.num_out = num_out
check_integer(num_out, 'num_out', min_bound=1, allow_none=False)
self.train_state = train_state


class VanillaRNN(RecurrentCell):
r"""Basic fully-connected RNN core.

Given :math:`x_t` and the previous hidden state :math:`h_{t-1}` the
core computes

.. math::

h_t = \mathrm{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h)

The output is equal to the new state, :math:`h_t`.


Parameters
----------
num_out: int
The number of hidden units in the node.
state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The state initializer.
Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The input weight initializer.
Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The hidden weight initializer.
b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray
The bias weight initializer.
activation: str, callable
The activation function. It can be a string or a callable function.
See ``brainpy.math.activations`` for more details.
train_state: bool
Whether the initial state is trainable.

"""

def __init__(
self,
num_in: int,
num_out: int,
state_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
Wi_initializer: Union[Array, Callable, Initializer] = XavierNormal(),
Wh_initializer: Union[Array, Callable, Initializer] = XavierNormal(),
b_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
activation: str = 'relu',
mode: Mode = training,
train_state: bool = False,
name: str = None,
):
super(VanillaRNN, self).__init__(num_out=num_out,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name)

# parameters
self.num_in = num_in
check_integer(num_in, 'num_in', min_bound=1, allow_none=False)

# initializers
self._Wi_initializer = Wi_initializer
self._Wh_initializer = Wh_initializer
self._b_initializer = b_initializer
check_initializer(Wi_initializer, 'wi_initializer', allow_none=False)
check_initializer(Wh_initializer, 'wh_initializer', allow_none=False)
check_initializer(b_initializer, 'b_initializer', allow_none=True)

# activation function
self.activation = bm.activations.get(activation)

# weights
self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out))
self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out))
self.b = parameter(self._b_initializer, (self.num_out,))
if isinstance(self.mode, TrainingMode):
self.Wi = bm.TrainVar(self.Wi)
self.Wh = bm.TrainVar(self.Wh)
self.b = None if (self.b is None) else bm.TrainVar(self.b)

# state
self.state = variable(bm.zeros, mode, self.num_out)
if train_state and isinstance(self.mode, TrainingMode):
self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False))
self.state[:] = self.state2train

def reset_state(self, batch_size=None):
self.state.value = parameter(self._state_initializer, (batch_size, self.num_out), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
self.state[:] = self.state2train

def update(self, sha, x):
h = x @ self.Wi
h += self.state.value @ self.Wh
if self.b is not None:
h += self.b
self.state.value = self.activation(h)
return self.state.value


class GRU(RecurrentCell):
r"""Gated Recurrent Unit.

The implementation is based on (Chung, et al., 2014) [1]_ with biases.

Given :math:`x_t` and the previous state :math:`h_{t-1}` the core computes

.. math::

\begin{array}{ll}
z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\
r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\
a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\
h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t
\end{array}

where :math:`z_t` and :math:`r_t` are the update and reset gates, respectively.

The output is equal to the new hidden state, :math:`h_t`.

Warning: Backwards compatibility of GRU weights is currently unsupported.

Parameters
----------
num_out: int
The number of hidden units in the node.
state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The state initializer.
Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The input weight initializer.
Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The hidden weight initializer.
b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray
The bias weight initializer.
activation: str, callable
The activation function. It can be a string or a callable function.
See ``brainpy.math.activations`` for more details.
train_state: bool
Whether the initial state is trainable.

References
----------
.. [1] Chung, J., Gulcehre, C., Cho, K. and Bengio, Y., 2014. Empirical
evaluation of gated recurrent neural networks on sequence modeling.
arXiv preprint arXiv:1412.3555.
"""

def __init__(
self,
num_in: int,
num_out: int,
Wi_initializer: Union[Array, Callable, Initializer] = Orthogonal(),
Wh_initializer: Union[Array, Callable, Initializer] = Orthogonal(),
b_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
state_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
activation: str = 'tanh',
mode: Mode = training,
train_state: bool = False,
name: str = None,
):
super(GRU, self).__init__(num_out=num_out,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name)
# parameters
self.num_in = num_in
check_integer(num_in, 'num_in', min_bound=1, allow_none=False)

# initializers
self._Wi_initializer = Wi_initializer
self._Wh_initializer = Wh_initializer
self._b_initializer = b_initializer
check_initializer(Wi_initializer, 'Wi_initializer', allow_none=False)
check_initializer(Wh_initializer, 'Wh_initializer', allow_none=False)
check_initializer(b_initializer, 'b_initializer', allow_none=True)

# activation function
self.activation = bm.activations.get(activation)

# weights
self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 3))
self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 3))
self.b = parameter(self._b_initializer, (self.num_out * 3,))
if isinstance(self.mode, TrainingMode):
self.Wi = bm.TrainVar(self.Wi)
self.Wh = bm.TrainVar(self.Wh)
self.b = bm.TrainVar(self.b) if (self.b is not None) else None

# state
self.state = variable(bm.zeros, mode, self.num_out)
if train_state and isinstance(self.mode, TrainingMode):
self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False))
self.state[:] = self.state2train

def reset_state(self, batch_size=None):
self.state.value = parameter(self._state_initializer, (batch_size, self.num_out), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
self.state[:] = self.state2train

def update(self, sha, x):
gates_x = bm.matmul(x, self.Wi)
zr_x, a_x = bm.split(gates_x, indices_or_sections=[2 * self.num_out], axis=-1)
w_h_z, w_h_a = bm.split(self.Wh, indices_or_sections=[2 * self.num_out], axis=-1)
zr_h = bm.matmul(self.state, w_h_z)
zr = zr_x + zr_h
has_bias = (self.b is not None)
if has_bias:
b_z, b_a = bm.split(self.b, indices_or_sections=[2 * self.num_out], axis=0)
zr += bm.broadcast_to(b_z, zr_h.shape)
z, r = bm.split(bm.sigmoid(zr), indices_or_sections=2, axis=-1)
a_h = bm.matmul(r * self.state, w_h_a)
if has_bias:
a = self.activation(a_x + a_h + bm.broadcast_to(b_a, a_h.shape))
else:
a = self.activation(a_x + a_h)
next_state = (1 - z) * self.state + z * a
self.state.value = next_state
return self.state.value


class LSTM(RecurrentCell):
r"""Long short-term memory (LSTM) RNN core.

The implementation is based on (Zaremba, et al., 2014) [1]_. Given
:math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core
computes

.. math::

\begin{array}{ll}
i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
c_t = f_t c_{t-1} + i_t g_t \\
h_t = o_t \tanh(c_t)
\end{array}

where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and
output gate activations, and :math:`g_t` is a vector of cell updates.

The output is equal to the new hidden, :math:`h_t`.

Notes
-----

Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_ we add 1.0
to :math:`b_f` after initialization in order to reduce the scale of forgetting in
the beginning of the training.


Parameters
----------
num_out: int
The number of hidden units in the node.
state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The state initializer.
Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The input weight initializer.
Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The hidden weight initializer.
b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray
The bias weight initializer.
activation: str, callable
The activation function. It can be a string or a callable function.
See ``brainpy.math.activations`` for more details.
train_state: bool
Whether the initial state is trainable.

References
----------

.. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural
network regularization." arXiv preprint arXiv:1409.2329 (2014).
.. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical
exploration of recurrent network architectures." In International conference
on machine learning, pp. 2342-2350. PMLR, 2015.
"""

def __init__(
self,
num_in: int,
num_out: int,
Wi_initializer: Union[Array, Callable, Initializer] = XavierNormal(),
Wh_initializer: Union[Array, Callable, Initializer] = XavierNormal(),
b_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
state_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
activation: str = 'tanh',
mode: Mode = training,
train_state: bool = False,
name: str = None,
):
super(LSTM, self).__init__(num_out=num_out,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name)
# parameters
self.num_in = num_in
check_integer(num_in, 'num_in', min_bound=1, allow_none=False)

# initializers
self._state_initializer = state_initializer
self._Wi_initializer = Wi_initializer
self._Wh_initializer = Wh_initializer
self._b_initializer = b_initializer
check_initializer(Wi_initializer, 'wi_initializer', allow_none=False)
check_initializer(Wh_initializer, 'wh_initializer', allow_none=False)
check_initializer(b_initializer, 'b_initializer', allow_none=True)
check_initializer(state_initializer, 'state_initializer', allow_none=False)

# activation function
self.activation = bm.activations.get(activation)

# weights
self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 4))
self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 4))
self.b = parameter(self._b_initializer, (self.num_out * 4,))
if isinstance(self.mode, TrainingMode):
self.Wi = bm.TrainVar(self.Wi)
self.Wh = bm.TrainVar(self.Wh)
self.b = None if (self.b is None) else bm.TrainVar(self.b)

# state
self.state = variable(bm.zeros, mode, self.num_out * 2)
if train_state and isinstance(self.mode, TrainingMode):
self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out * 2,), allow_none=False))
self.state[:] = self.state2train

def reset_state(self, batch_size=None):
self.state.value = parameter(self._state_initializer, (batch_size, self.num_out * 2), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out * 2, allow_none=False)
self.state[:] = self.state2train

def update(self, sha, x):
h, c = bm.split(self.state, 2, axis=-1)
gated = x @ self.Wi
if self.b is not None:
gated += self.b
gated += h @ self.Wh
i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1)
c = bm.sigmoid(f + 1.) * c + bm.sigmoid(i) * self.activation(g)
h = bm.sigmoid(o) * self.activation(c)
self.state.value = bm.concatenate([h, c], axis=-1)
return h

@property
def h(self):
"""Hidden state."""
return bm.split(self.state, 2, axis=-1)[0]

@h.setter
def h(self, value):
if self.state is None:
raise ValueError('Cannot set "h" state. Because the state is not initialized.')
self.state[..., :self.state.shape[-1] // 2] = value

@property
def c(self):
"""Memory cell."""
return bm.split(self.state, 2, axis=-1)[1]

@c.setter
def c(self, value):
if self.state is None:
raise ValueError('Cannot set "c" state. Because the state is not initialized.')
self.state[..., self.state.shape[-1] // 2:] = value


class ConvNDLSTM(DynamicalSystem):
pass


class Conv1DLSTM(ConvNDLSTM):
pass


class Conv2DLSTM(ConvNDLSTM):
pass


class Conv3DLSTM(ConvNDLSTM):
pass

+ 95
- 0
brainpy/dyn/layers/tests/test_conv.py View File

@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
import random

import pytest
from unittest import TestCase
import brainpy as bp
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt


class TestConv(TestCase):
def test_Conv2D_img(self):
class Convnet(bp.dyn.DynamicalSystem):
def __init__(self):
super(Convnet, self).__init__()
self.conv = bp.layers.Conv2D(in_channels=4, out_channels=32, kernel_size=(3, 3),
strides=(1, 1), padding='SAME', groups=1)

def update(self, shared, x):
x = self.conv(shared, x)
return x

img = jnp.zeros((2, 200, 198, 4))
for k in range(4):
x = 30 + 60 * k
y = 20 + 60 * k
img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)

net = Convnet()
out = net(None, img)
print("out shape: ", out.shape)
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(img)[0, :, :, 0])
# plt.show()

def test_conv1D(self):
class Convnet(bp.dyn.DynamicalSystem):
def __init__(self):
super(Convnet, self).__init__()
self.conv = bp.layers.Conv1D(in_channels=3, out_channels=32, kernel_size=(3,))

def update(self, shared, x):
x = self.conv(shared, x)
return x

model = Convnet()
input = bp.math.ones((2, 5, 3))

out = model(None, input)
print("out shape: ", out.shape)
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :])
# plt.show()

def test_conv2D(self):
class Convnet(bp.dyn.DynamicalSystem):
def __init__(self):
super(Convnet, self).__init__()
self.conv = bp.layers.Conv2D(in_channels=3, out_channels=32, kernel_size=(3, 3))

def update(self, shared, x):
x = self.conv(shared, x)
return x

model = Convnet()

input = bp.math.ones((2, 5, 5, 3))

out = model(None, input)
print("out shape: ", out.shape)
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :, 31])
# plt.show()

def test_conv3D(self):
class Convnet(bp.dyn.DynamicalSystem):
def __init__(self):
super(Convnet, self).__init__()
self.conv = bp.layers.Conv3D(in_channels=3, out_channels=32, kernel_size=(3, 3, 3))

def update(self, shared, x):
x = self.conv(shared, x)
return x

model = Convnet()

input = bp.math.ones((2, 5, 5, 5, 3))

out = model(None, input)
print("out shape: ", out.shape)

+ 201
- 0
brainpy/dyn/layers/tests/test_normalization.py View File

@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-


from unittest import TestCase

import brainpy as bp


class TestBatchNorm1d(TestCase):
def test_batchnorm1d1(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm1d(axis=(0, 1, 2))

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))

def test_batchnorm1d2(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm1d()
self.dense = bp.dyn.layers.Dense(num_in=4, num_out=4)

def update(self, shared, x):
x = self.norm(shared, x)
x = self.dense(shared, x)
return x

inputs = bp.math.ones((2, 4))
inputs[0, :] = 2.
print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestBatchNorm2d(TestCase):
def test_batchnorm2d(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm2d()

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((10, 32, 32, 3))
inputs[0, 1, :, :] = 2.
print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestBatchNorm3d(TestCase):
def test_batchnorm3d(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm3d()

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((10, 32, 32, 16, 3))
print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestBatchNorm(TestCase):
def test_batchnorm1(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm(axis=(0, 2), use_bias=False) # channel axis: 1

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))

def test_batchnorm2(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm(axis=(0, 2)) # channel axis: 1
self.dense = bp.dyn.layers.Dense(num_in=12, num_out=2)

def update(self, shared, x):
x = self.norm(shared, x)
x = x.reshape(-1, 12)
x = self.dense(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
# print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestLayerNorm(TestCase):
def test_layernorm1(self):
class LayerNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(LayerNormNet, self).__init__()
self.norm = bp.dyn.layers.LayerNorm()

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = LayerNormNet()
shared = {'fit': False}
print(model(shared, inputs))

def test_layernorm2(self):
class LayerNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(LayerNormNet, self).__init__()
self.norm = bp.dyn.layers.LayerNorm(axis=2)

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = LayerNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestInstanceNorm(TestCase):
def test_instancenorm(self):
class InstanceNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(InstanceNormNet, self).__init__()
self.norm = bp.dyn.layers.InstanceNorm()

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = InstanceNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestGroupNorm(TestCase):
def test_groupnorm1(self):
class GroupNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(GroupNormNet, self).__init__()
self.norm = bp.dyn.layers.GroupNorm(num_groups=2)

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = GroupNormNet()
shared = {'fit': False}
print(model(shared, inputs))

+ 70
- 0
brainpy/dyn/layers/tests/test_pooling.py View File

@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
import random

import pytest
from unittest import TestCase
import brainpy as bp
import jax.numpy as jnp
import jax
import numpy as np


class TestPool(TestCase):
def test_maxpool(self):
class MaxPoolNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(MaxPoolNet, self).__init__()
self.maxpool = bp.dyn.layers.MaxPool((2, 2))

def update(self, sha, x):
x = self.maxpool(sha, x)
return x

x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
shared = {'fit': False}
net = MaxPoolNet()
y = net(shared, x)
print("out shape: ", y.shape)
expected_y = jnp.array([
[4., 5.],
[7., 8.],
]).reshape((1, 2, 2, 1))
np.testing.assert_allclose(y, expected_y)

def test_minpool(self):
class MinPoolNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(MinPoolNet, self).__init__()
self.maxpool = bp.dyn.layers.MinPool((2, 2))

def update(self, sha, x):
x = self.maxpool(sha, x)
return x

x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
shared = {'fit': False}
net = MinPoolNet()
y = net(shared, x)
print("out shape: ", y.shape)
expected_y = jnp.array([
[0., 1.],
[3., 4.],
]).reshape((1, 2, 2, 1))
np.testing.assert_allclose(y, expected_y)

def test_avgpool(self):
class AvgPoolNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(AvgPoolNet, self).__init__()
self.maxpool = bp.dyn.layers.AvgPool((2, 2))

def update(self, sha, x):
x = self.maxpool(sha, x)
return x

x = jnp.full((1, 3, 3, 1), 2.)
shared = {'fit': False}
net = AvgPoolNet()
y = net(shared, x)
print("out shape: ", y.shape)
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
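
For reference, the expected arrays in `test_maxpool` and `test_minpool` are consistent with a 2x2 window sliding at stride 1 over the 3x3 input (the min-pool values rule out a stride-2 reading). A minimal NumPy sketch, independent of the BrainPy layers, that reproduces them:

import numpy as np

x = np.arange(9, dtype=np.float32).reshape(3, 3)   # [[0 1 2], [3 4 5], [6 7 8]]
max_out = np.empty((2, 2), dtype=np.float32)
min_out = np.empty((2, 2), dtype=np.float32)
for i in range(2):
    for j in range(2):
        window = x[i:i + 2, j:j + 2]                # 2x2 window at stride 1
        max_out[i, j] = window.max()
        min_out[i, j] = window.min()
print(max_out)   # [[4. 5.] [7. 8.]]  -> expected_y in test_maxpool
print(min_out)   # [[0. 1.] [3. 4.]]  -> expected_y in test_minpool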

+ 1
- 0
brainpy/dyn/networks/__init__.py View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

+ 25
- 0
brainpy/dyn/networks/cann.py View File

@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-


from brainpy.dyn.base import NeuGroup

__all__ = [
'WuCANN1D',
'WuCANN2D',
]


class WuCANN1D(NeuGroup):
pass


class WuCANN2D(NeuGroup):
pass


class ACANN_1D(NeuGroup):
pass


class ACANN_2D(NeuGroup):
pass

+ 2
- 1
brainpy/dyn/neurons/__init__.py View File

@@ -2,5 +2,6 @@

from .biological_models import *
from .fractional_models import *
from .input_models import *
from .reduced_models import *
from .input_groups import *
from .noise_groups import *

+ 458
- 185
brainpy/dyn/neurons/biological_models.py View File

@@ -1,19 +1,22 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable
from typing import Union, Callable, Optional

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import OneInit, Uniform, Initializer, init_param
from brainpy.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.integrators.sde import sdeint
from brainpy.modes import Mode, BatchingMode, TrainingMode, NormalMode, normal, check
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Tensor
from brainpy.types import Shape, Array

__all__ = [
'HH',
'MorrisLecar',
'PinskyRinzelModel',
'WangBuzsakiModel',
]


@@ -191,38 +194,48 @@ class HH(NeuGroup):
def __init__(
self,
size: Shape,
ENa: Union[float, Tensor, Initializer, Callable] = 50.,
gNa: Union[float, Tensor, Initializer, Callable] = 120.,
EK: Union[float, Tensor, Initializer, Callable] = -77.,
gK: Union[float, Tensor, Initializer, Callable] = 36.,
EL: Union[float, Tensor, Initializer, Callable] = -54.387,
gL: Union[float, Tensor, Initializer, Callable] = 0.03,
V_th: Union[float, Tensor, Initializer, Callable] = 20.,
C: Union[float, Tensor, Initializer, Callable] = 1.0,
V_initializer: Union[Initializer, Callable, Tensor] = Uniform(-70, -60.),
m_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.5),
h_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.6),
n_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.32),
keep_size: bool = False,
ENa: Union[float, Array, Initializer, Callable] = 50.,
gNa: Union[float, Array, Initializer, Callable] = 120.,
EK: Union[float, Array, Initializer, Callable] = -77.,
gK: Union[float, Array, Initializer, Callable] = 36.,
EL: Union[float, Array, Initializer, Callable] = -54.387,
gL: Union[float, Array, Initializer, Callable] = 0.03,
V_th: Union[float, Array, Initializer, Callable] = 20.,
C: Union[float, Array, Initializer, Callable] = 1.0,
V_initializer: Union[Initializer, Callable, Array] = Uniform(-70, -60.),
m_initializer: Optional[Union[Initializer, Callable, Array]] = None,
h_initializer: Optional[Union[Initializer, Callable, Array]] = None,
n_initializer: Optional[Union[Initializer, Callable, Array]] = None,
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
name: str = None,

# training parameter
mode: Mode = normal,
):
# initialization
super(HH, self).__init__(size=size, name=name)
super(HH, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)

# parameters
self.ENa = init_param(ENa, self.num, allow_none=False)
self.EK = init_param(EK, self.num, allow_none=False)
self.EL = init_param(EL, self.num, allow_none=False)
self.gNa = init_param(gNa, self.num, allow_none=False)
self.gK = init_param(gK, self.num, allow_none=False)
self.gL = init_param(gL, self.num, allow_none=False)
self.C = init_param(C, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.ENa = parameter(ENa, self.varshape, allow_none=False)
self.EK = parameter(EK, self.varshape, allow_none=False)
self.EL = parameter(EL, self.varshape, allow_none=False)
self.gNa = parameter(gNa, self.varshape, allow_none=False)
self.gK = parameter(gK, self.varshape, allow_none=False)
self.gL = parameter(gL, self.varshape, allow_none=False)
self.C = parameter(C, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=4)

# initializers
check_initializer(m_initializer, 'm_initializer', allow_none=False)
check_initializer(h_initializer, 'h_initializer', allow_none=False)
check_initializer(n_initializer, 'n_initializer', allow_none=False)
check_initializer(m_initializer, 'm_initializer', allow_none=True)
check_initializer(h_initializer, 'h_initializer', allow_none=True)
check_initializer(n_initializer, 'n_initializer', allow_none=True)
check_initializer(V_initializer, 'V_initializer', allow_none=False)
self._m_initializer = m_initializer
self._h_initializer = h_initializer
@@ -230,41 +243,62 @@ class HH(NeuGroup):
self._V_initializer = V_initializer

# variables
self.m = bm.Variable(init_param(self._m_initializer, (self.num,)))
self.h = bm.Variable(init_param(self._h_initializer, (self.num,)))
self.n = bm.Variable(init_param(self._n_initializer, (self.num,)))
self.V = bm.Variable(init_param(self._V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = variable(self._V_initializer, mode, self.varshape)
if self._m_initializer is None:
self.m = bm.Variable(self.m_inf(self.V.value))
else:
self.m = variable(self._m_initializer, mode, self.varshape)
if self._h_initializer is None:
self.h = bm.Variable(self.h_inf(self.V.value))
else:
self.h = variable(self._h_initializer, mode, self.varshape)
if self._n_initializer is None:
self.n = bm.Variable(self.n_inf(self.V.value))
else:
self.n = variable(self._n_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
self.integral = odeint(method=method, f=self.derivative)

def reset(self):
self.m.value = init_param(self._m_initializer, (self.num,))
self.h.value = init_param(self._h_initializer, (self.num,))
self.n.value = init_param(self._n_initializer, (self.num,))
self.V.value = init_param(self._V_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False

def dm(self, m, t, V):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
beta = 4.0 * bm.exp(-(V + 65) / 18)
dmdt = alpha * (1 - m) - beta * m
return dmdt

def dh(self, h, t, V):
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
dhdt = alpha * (1 - h) - beta * h
return dhdt

def dn(self, n, t, V):
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
beta = 0.125 * bm.exp(-(V + 65) / 80)
dndt = alpha * (1 - n) - beta * n
return dndt
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

# m channel
m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18)
m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))
dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m

# h channel
h_alpha = lambda self, V: 0.07 * bm.exp(-(V + 65) / 20.)
h_beta = lambda self, V: 1 / (1 + bm.exp(-(V + 35) / 10))
h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V))
dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h

# n channel
n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80)
n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))
dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
if self._m_initializer is None:
self.m.value = self.m_inf(self.V.value)
else:
self.m.value = variable(self._m_initializer, batch_size, self.varshape)
if self._h_initializer is None:
self.h.value = self.h_inf(self.V.value)
else:
self.h.value = variable(self._h_initializer, batch_size, self.varshape)
if self._n_initializer is None:
self.n.value = self.n_inf(self.V.value)
else:
self.n.value = variable(self._n_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def dV(self, V, t, m, h, n, I_ext):
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
@@ -277,13 +311,17 @@ class HH(NeuGroup):
def derivative(self):
return JointEq([self.dV, self.dm, self.dh, self.dn])

def update(self, t, dt):
V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, self.input, dt=dt)
def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']
if x is not None: self.input += x
V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, self.input, dt)
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.m.value = m
self.h.value = h
self.n.value = n

def clear_input(self):
self.input[:] = 0.


@@ -356,10 +394,7 @@ class MorrisLecar(NeuGroup):
References
----------

.. [4] Meier, Stephen R., Jarrett L. Lancaster, and Joseph M. Starobin.
"Bursting regimes in a reaction-diffusion system with action
potential-dependent equilibrium." PloS one 10.3 (2015):
e0122401.
.. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333.
.. [5] http://www.scholarpedia.org/article/Morris-Lecar_model
.. [6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model
"""
@@ -367,41 +402,51 @@ class MorrisLecar(NeuGroup):
def __init__(
self,
size: Shape,
V_Ca: Union[float, Tensor, Initializer, Callable] = 130.,
g_Ca: Union[float, Tensor, Initializer, Callable] = 4.4,
V_K: Union[float, Tensor, Initializer, Callable] = -84.,
g_K: Union[float, Tensor, Initializer, Callable] = 8.,
V_leak: Union[float, Tensor, Initializer, Callable] = -60.,
g_leak: Union[float, Tensor, Initializer, Callable] = 2.,
C: Union[float, Tensor, Initializer, Callable] = 20.,
V1: Union[float, Tensor, Initializer, Callable] = -1.2,
V2: Union[float, Tensor, Initializer, Callable] = 18.,
V3: Union[float, Tensor, Initializer, Callable] = 2.,
V4: Union[float, Tensor, Initializer, Callable] = 30.,
phi: Union[float, Tensor, Initializer, Callable] = 0.04,
V_th: Union[float, Tensor, Initializer, Callable] = 10.,
W_initializer: Union[Callable, Initializer, Tensor] = OneInit(0.02),
V_initializer: Union[Callable, Initializer, Tensor] = Uniform(-70., -60.),
keep_size: bool = False,
V_Ca: Union[float, Array, Initializer, Callable] = 130.,
g_Ca: Union[float, Array, Initializer, Callable] = 4.4,
V_K: Union[float, Array, Initializer, Callable] = -84.,
g_K: Union[float, Array, Initializer, Callable] = 8.,
V_leak: Union[float, Array, Initializer, Callable] = -60.,
g_leak: Union[float, Array, Initializer, Callable] = 2.,
C: Union[float, Array, Initializer, Callable] = 20.,
V1: Union[float, Array, Initializer, Callable] = -1.2,
V2: Union[float, Array, Initializer, Callable] = 18.,
V3: Union[float, Array, Initializer, Callable] = 2.,
V4: Union[float, Array, Initializer, Callable] = 30.,
phi: Union[float, Array, Initializer, Callable] = 0.04,
V_th: Union[float, Array, Initializer, Callable] = 10.,
W_initializer: Union[Callable, Initializer, Array] = OneInit(0.02),
V_initializer: Union[Callable, Initializer, Array] = Uniform(-70., -60.),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
name: str = None,

# training parameter
mode: Mode = normal,
):
# initialization
super(MorrisLecar, self).__init__(size=size, name=name)
super(MorrisLecar, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__)

# params
self.V_Ca = init_param(V_Ca, self.num, allow_none=False)
self.g_Ca = init_param(g_Ca, self.num, allow_none=False)
self.V_K = init_param(V_K, self.num, allow_none=False)
self.g_K = init_param(g_K, self.num, allow_none=False)
self.V_leak = init_param(V_leak, self.num, allow_none=False)
self.g_leak = init_param(g_leak, self.num, allow_none=False)
self.C = init_param(C, self.num, allow_none=False)
self.V1 = init_param(V1, self.num, allow_none=False)
self.V2 = init_param(V2, self.num, allow_none=False)
self.V3 = init_param(V3, self.num, allow_none=False)
self.V4 = init_param(V4, self.num, allow_none=False)
self.phi = init_param(phi, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_Ca = parameter(V_Ca, self.varshape, allow_none=False)
self.g_Ca = parameter(g_Ca, self.varshape, allow_none=False)
self.V_K = parameter(V_K, self.varshape, allow_none=False)
self.g_K = parameter(g_K, self.varshape, allow_none=False)
self.V_leak = parameter(V_leak, self.varshape, allow_none=False)
self.g_leak = parameter(g_leak, self.varshape, allow_none=False)
self.C = parameter(C, self.varshape, allow_none=False)
self.V1 = parameter(V1, self.varshape, allow_none=False)
self.V2 = parameter(V2, self.varshape, allow_none=False)
self.V3 = parameter(V3, self.varshape, allow_none=False)
self.V4 = parameter(V4, self.varshape, allow_none=False)
self.phi = parameter(phi, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=2)

# initializers
check_initializer(V_initializer, 'V_initializer', allow_none=False)
@@ -410,19 +455,22 @@ class MorrisLecar(NeuGroup):
self._V_initializer = V_initializer

# variables
self.W = bm.Variable(init_param(W_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.W = variable(self._W_initializer, mode, self.varshape)
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset(self):
self.W.value = init_param(self._W_initializer, (self.num,))
self.V.value = init_param(self._V_initializer, (self.num,))
self.input.value = bm.zeros(self.num)
self.spike.value = bm.zeros(self.num, dtype=bool)
def reset_state(self, batch_size=None):
self.W.value = variable(self._W_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def dV(self, V, t, W, I_ext):
M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
@@ -442,11 +490,15 @@ class MorrisLecar(NeuGroup):
def derivative(self):
return JointEq([self.dV, self.dW])

def update(self, t, dt):
V, self.W.value = self.integral(self.V, self.W, t, self.input, dt=dt)
def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']
if x is not None: self.input += x
V, self.W.value = self.integral(self.V, self.W, t, self.input, dt)
spike = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.spike.value = spike

def clear_input(self):
self.input[:] = 0.


@@ -602,55 +654,63 @@ class PinskyRinzelModel(NeuGroup):
def __init__(
self,
size: Shape,
keep_size: bool = False,
# maximum conductance
gNa: Union[float, Tensor, Initializer, Callable] = 30.,
gK: Union[float, Tensor, Initializer, Callable] = 15.,
gCa: Union[float, Tensor, Initializer, Callable] = 10.,
gAHP: Union[float, Tensor, Initializer, Callable] = 0.8,
gC: Union[float, Tensor, Initializer, Callable] = 15.,
gL: Union[float, Tensor, Initializer, Callable] = 0.1,
gNa: Union[float, Array, Initializer, Callable] = 30.,
gK: Union[float, Array, Initializer, Callable] = 15.,
gCa: Union[float, Array, Initializer, Callable] = 10.,
gAHP: Union[float, Array, Initializer, Callable] = 0.8,
gC: Union[float, Array, Initializer, Callable] = 15.,
gL: Union[float, Array, Initializer, Callable] = 0.1,
# reversal potential
ENa: Union[float, Tensor, Initializer, Callable] = 60.,
EK: Union[float, Tensor, Initializer, Callable] = -75.,
ECa: Union[float, Tensor, Initializer, Callable] = 80.,
EL: Union[float, Tensor, Initializer, Callable] = -60.,
ENa: Union[float, Array, Initializer, Callable] = 60.,
EK: Union[float, Array, Initializer, Callable] = -75.,
ECa: Union[float, Array, Initializer, Callable] = 80.,
EL: Union[float, Array, Initializer, Callable] = -60.,
# other parameters
gc: Union[float, Tensor, Initializer, Callable] = 2.1,
V_th: Union[float, Tensor, Initializer, Callable] = 20.,
Cm: Union[float, Tensor, Initializer, Callable] = 3.0,
p: Union[float, Tensor, Initializer, Callable] = 0.5,
A: Union[float, Tensor, Initializer, Callable] = 1.,
gc: Union[float, Array, Initializer, Callable] = 2.1,
V_th: Union[float, Array, Initializer, Callable] = 20.,
Cm: Union[float, Array, Initializer, Callable] = 3.0,
p: Union[float, Array, Initializer, Callable] = 0.5,
A: Union[float, Array, Initializer, Callable] = 1.,
# initializers
Vs_initializer: Union[Initializer, Callable, Tensor] = OneInit(-64.6),
Vd_initializer: Union[Initializer, Callable, Tensor] = OneInit(-64.5),
Ca_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.2),
Vs_initializer: Union[Initializer, Callable, Array] = OneInit(-64.6),
Vd_initializer: Union[Initializer, Callable, Array] = OneInit(-64.5),
Ca_initializer: Union[Initializer, Callable, Array] = OneInit(0.2),
# others
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
# initialization
super(PinskyRinzelModel, self).__init__(size=size, name=name)
super(PinskyRinzelModel, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (NormalMode, BatchingMode), self.__class__)

# conductance parameters
self.gAHP = init_param(gAHP, self.num, allow_none=False)
self.gCa = init_param(gCa, self.num, allow_none=False)
self.gNa = init_param(gNa, self.num, allow_none=False)
self.gK = init_param(gK, self.num, allow_none=False)
self.gL = init_param(gL, self.num, allow_none=False)
self.gC = init_param(gC, self.num, allow_none=False)
self.gAHP = parameter(gAHP, self.varshape, allow_none=False)
self.gCa = parameter(gCa, self.varshape, allow_none=False)
self.gNa = parameter(gNa, self.varshape, allow_none=False)
self.gK = parameter(gK, self.varshape, allow_none=False)
self.gL = parameter(gL, self.varshape, allow_none=False)
self.gC = parameter(gC, self.varshape, allow_none=False)

# reversal potential parameters
self.ENa = init_param(ENa, self.num, allow_none=False)
self.ECa = init_param(ECa, self.num, allow_none=False)
self.EK = init_param(EK, self.num, allow_none=False)
self.EL = init_param(EL, self.num, allow_none=False)
self.ENa = parameter(ENa, self.varshape, allow_none=False)
self.ECa = parameter(ECa, self.varshape, allow_none=False)
self.EK = parameter(EK, self.varshape, allow_none=False)
self.EL = parameter(EL, self.varshape, allow_none=False)

# other neuronal parameters
self.V_th = init_param(V_th, self.num, allow_none=False)
self.Cm = init_param(Cm, self.num, allow_none=False)
self.gc = init_param(gc, self.num, allow_none=False)
self.p = init_param(p, self.num, allow_none=False)
self.A = init_param(A, self.num, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.Cm = parameter(Cm, self.varshape, allow_none=False)
self.gc = parameter(gc, self.varshape, allow_none=False)
self.p = parameter(p, self.varshape, allow_none=False)
self.A = parameter(A, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=8)

# initializers
check_initializer(Vs_initializer, 'Vs_initializer', allow_none=False)
@@ -661,47 +721,56 @@ class PinskyRinzelModel(NeuGroup):
self._Ca_initializer = Ca_initializer

# variables
self.Vs = bm.Variable(init_param(self._Vs_initializer, (self.num,)))
self.Vd = bm.Variable(init_param(self._Vd_initializer, (self.num,)))
self.Ca = bm.Variable(init_param(self._Ca_initializer, (self.num,)))
self.h = bm.Variable(self.inf_h(self.Vs))
self.n = bm.Variable(self.inf_n(self.Vs))
self.s = bm.Variable(self.inf_s(self.Vd))
self.c = bm.Variable(self.inf_c(self.Vd))
self.q = bm.Variable(self.inf_q(self.Ca))
self.Id = bm.Variable(bm.zeros((self.num,))) # input to soma
self.Is = bm.Variable(bm.zeros((self.num,))) # input to dendrite
# self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.Vs = variable(self._Vs_initializer, mode, self.varshape)
self.Vd = variable(self._Vd_initializer, mode, self.varshape)
self.Ca = variable(self._Ca_initializer, mode, self.varshape)
self.h = bm.Variable(self.inf_h(self.Vs), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.n = bm.Variable(self.inf_n(self.Vs), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.s = bm.Variable(self.inf_s(self.Vd), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.c = bm.Variable(self.inf_c(self.Vd), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.q = bm.Variable(self.inf_q(self.Ca), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.Id = variable(bm.zeros, mode, self.varshape) # input to soma
self.Is = variable(bm.zeros, mode, self.varshape) # input to dendrite
# self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool))

# integral
self.integral = odeint(method=method, f=self.derivative)

def reset(self):
self.Vd.value = init_param(self._Vd_initializer, (self.num,))
self.Vs.value = init_param(self._Vs_initializer, (self.num,))
self.Ca.value = init_param(self._Ca_initializer, (self.num,))
self.h.value = self.inf_h(self.Vs)
self.n.value = self.inf_n(self.Vs)
self.s.value = self.inf_s(self.Vd)
self.c.value = self.inf_c(self.Vd)
self.q.value = self.inf_q(self.Ca)
self.Id[:] = 0
self.Is[:] = 0
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.Vd.value = variable(self._Vd_initializer, batch_size, self.varshape)
self.Vs.value = variable(self._Vs_initializer, batch_size, self.varshape)
self.Ca.value = variable(self._Ca_initializer, batch_size, self.varshape)
batch_axis = 0 if isinstance(self.mode, BatchingMode) else None
self.h.value = bm.Variable(self.inf_h(self.Vs), batch_axis=batch_axis)
self.n.value = bm.Variable(self.inf_n(self.Vs), batch_axis=batch_axis)
self.s.value = bm.Variable(self.inf_s(self.Vd), batch_axis=batch_axis)
self.c.value = bm.Variable(self.inf_c(self.Vd), batch_axis=batch_axis)
self.q.value = bm.Variable(self.inf_q(self.Ca), batch_axis=batch_axis)
self.Id.value = variable(bm.zeros, batch_size, self.varshape)
self.Is.value = variable(bm.zeros, batch_size, self.varshape)
# self.spike[:] = False

def dCa(self, Ca, t, s, Vd):
ICa = self.gCa * s * s * (Vd - self.ECa)
return -0.13 * ICa - 0.075 * Ca

def dh(self, h, t, Vs): return self.alpha_h(Vs) * (1 - h) - self.beta_h(Vs) * h
def dh(self, h, t, Vs):
return self.alpha_h(Vs) * (1 - h) - self.beta_h(Vs) * h

def dn(self, n, t, Vs): return self.alpha_n(Vs) * (1 - n) - self.beta_n(Vs) * n
def dn(self, n, t, Vs):
return self.alpha_n(Vs) * (1 - n) - self.beta_n(Vs) * n

def ds(self, s, t, Vd): return self.alpha_s(Vd) * (1 - s) - self.beta_s(Vd) * s
def ds(self, s, t, Vd):
return self.alpha_s(Vd) * (1 - s) - self.beta_s(Vd) * s

def dc(self, c, t, Vd): return self.alpha_c(Vd) * (1 - c) - self.beta_c(Vd) * c
def dc(self, c, t, Vd):
return self.alpha_c(Vd) * (1 - c) - self.beta_c(Vd) * c

def dq(self, q, t, Ca): return self.alpha_q(Ca) * (1 - q) - self.beta_q(Ca) * q
def dq(self, q, t, Ca):
return self.alpha_q(Ca) * (1 - q) - self.beta_q(Ca) * q

def dVs(self, Vs, t, h, n, Vd):
I_Na = (self.gNa * self.inf_m(Vs) ** 2 * h) * (Vs - self.ENa)
@@ -725,7 +794,8 @@ class PinskyRinzelModel(NeuGroup):
def derivative(self):
return JointEq([self.dVs, self.dVd, self.dCa, self.dh, self.dn, self.ds, self.dc, self.dq])

def update(self, t, dt):
def update(self, tdi, x=None):
assert x is None
Vs, Vd, Ca, h, n, s, c, q = self.integral(Vs=self.Vs.value,
Vd=self.Vd.value,
Ca=self.Ca.value,
@@ -734,8 +804,8 @@ class PinskyRinzelModel(NeuGroup):
s=self.s.value,
c=self.c.value,
q=self.q.value,
t=t,
dt=dt)
t=tdi['t'],
dt=tdi['dt'])
self.Vs.value = Vs
self.Vd.value = Vd
self.Ca.value = Ca
@@ -744,39 +814,49 @@ class PinskyRinzelModel(NeuGroup):
self.s.value = s
self.c.value = c
self.q.value = q

def clear_input(self):
self.Id[:] = 0.
self.Is[:] = 0.

def alpha_m(self, Vs): return 0.32 * (13.1 - (Vs + 60.)) / (bm.exp((13.1 - (Vs + 60.)) / 4.) - 1.)
def alpha_m(self, Vs):
return 0.32 * (13.1 - (Vs + 60.)) / (bm.exp((13.1 - (Vs + 60.)) / 4.) - 1.)

def beta_m(self, Vs): return 0.28 * ((Vs + 60.) - 40.1) / (bm.exp(((Vs + 60.) - 40.1) / 5.) - 1.)
def beta_m(self, Vs):
return 0.28 * ((Vs + 60.) - 40.1) / (bm.exp(((Vs + 60.) - 40.1) / 5.) - 1.)

def inf_m(self, Vs):
alpha = self.alpha_m(Vs)
beta = self.beta_m(Vs)
return alpha / (alpha + beta)

def alpha_n(self, Vs): return 0.016 * (35.1 - (Vs + 60.)) / (bm.exp((35.1 - (Vs + 60.)) / 5) - 1)
def alpha_n(self, Vs):
return 0.016 * (35.1 - (Vs + 60.)) / (bm.exp((35.1 - (Vs + 60.)) / 5) - 1)

def beta_n(self, Vs): return 0.25 * bm.exp(0.5 - 0.025 * (Vs + 60.))
def beta_n(self, Vs):
return 0.25 * bm.exp(0.5 - 0.025 * (Vs + 60.))

def inf_n(self, Vs):
alpha = self.alpha_n(Vs)
beta = self.beta_n(Vs)
return alpha / (alpha + beta)

def alpha_h(self, Vs): return 0.128 * bm.exp((17. - (Vs + 60.)) / 18.)
def alpha_h(self, Vs):
return 0.128 * bm.exp((17. - (Vs + 60.)) / 18.)

def beta_h(self, Vs): return 4. / (1 + bm.exp((40. - (Vs + 60.)) / 5))
def beta_h(self, Vs):
return 4. / (1 + bm.exp((40. - (Vs + 60.)) / 5))

def inf_h(self, Vs):
alpha = self.alpha_h(Vs)
beta = self.beta_h(Vs)
return alpha / (alpha + beta)

def alpha_s(self, Vd): return 1.6 / (1 + bm.exp(-0.072 * ((Vd + 60.) - 65.)))
def alpha_s(self, Vd):
return 1.6 / (1 + bm.exp(-0.072 * ((Vd + 60.) - 65.)))

def beta_s(self, Vd): return 0.02 * ((Vd + 60.) - 51.1) / (bm.exp(((Vd + 60.) - 51.1) / 5.) - 1.)
def beta_s(self, Vd):
return 0.02 * ((Vd + 60.) - 51.1) / (bm.exp(((Vd + 60.) - 51.1) / 5.) - 1.)

def inf_s(self, Vd):
alpha = self.alpha_s(Vd)
@@ -797,11 +877,204 @@ class PinskyRinzelModel(NeuGroup):
beta_c = self.beta_c(Vd)
return alpha_c / (alpha_c + beta_c)

def alpha_q(self, Ca): return bm.minimum(2e-5 * Ca, 1e-2)
def alpha_q(self, Ca):
return bm.minimum(2e-5 * Ca, 1e-2)

def beta_q(self, Ca): return 1e-3
def beta_q(self, Ca):
return 1e-3

def inf_q(self, Ca):
alpha = self.alpha_q(Ca)
beta = self.beta_q(Ca)
return alpha / (alpha + beta)


class WangBuzsakiModel(NeuGroup):
r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model.

Each model is described by a single compartment and obeys the current balance equation:

.. math::

C_{m} \frac{d V}{d t}=-I_{\mathrm{Na}}-I_{\mathrm{K}}-I_{\mathrm{L}}-I_{\mathrm{syn}}+I_{\mathrm{app}}

where :math:`C_{m}=1 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`I_{\mathrm{app}}` is the
injected current (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ). The leak current
:math:`I_{\mathrm{L}}=g_{\mathrm{L}}\left(V-E_{\mathrm{L}}\right)` has a conductance
:math:`g_{\mathrm{L}}=0.1 \mathrm{mS} / \mathrm{cm}^{2}`, so that the passive time constant
:math:`\tau_{0}=C_{m} / g_{\mathrm{L}}=10 \mathrm{msec} ; E_{\mathrm{L}}=-65 \mathrm{mV}`.

The spike-generating :math:`\mathrm{Na}^{+}` and :math:`\mathrm{K}^{+}` voltage-dependent ion
currents :math:`\left(I_{\mathrm{Na}}\right.` and :math:`I_{\mathrm{K}}` ) are of the
Hodgkin-Huxley type (Hodgkin and Huxley, 1952). The transient sodium current
:math:`I_{\mathrm{Na}}=g_{\mathrm{Na}} m_{\infty}^{3} h\left(V-E_{\mathrm{Na}}\right)`,
where the activation variable :math:`m` is assumed fast and substituted by its steady-state
function :math:`m_{\infty}=\alpha_{m} /\left(\alpha_{m}+\beta_{m}\right)` ;
:math:`\alpha_{m}(V)=-0.1(V+35) /(\exp (-0.1(V+35))-1), \beta_{m}(V)=4 \exp (-(V+60) / 18)`.
The inactivation variable :math:`h` obeys a first-order kinetics:

.. math::

\frac{d h}{d t}=\phi\left(\alpha_{h}(1-h)-\beta_{h} h\right)

where :math:`\alpha_{h}(V)=0.07 \exp (-(V+58) / 20)` and
:math:`\beta_{h}(V)=1 /(\exp (-0.1(V+28)) +1) \cdot g_{\mathrm{Na}}=35 \mathrm{mS} / \mathrm{cm}^{2}` ;
:math:`E_{\mathrm{Na}}=55 \mathrm{mV}, \phi=5 .`

The delayed rectifier :math:`I_{\mathrm{K}}=g_{\mathrm{K}} n^{4}\left(V-E_{\mathrm{K}}\right)`,
where the activation variable :math:`n` obeys the following equation:

.. math::

\frac{d n}{d t}=\phi\left(\alpha_{n}(1-n)-\beta_{n} n\right)

with :math:`\alpha_{n}(V)=-0.01(V+34) /(\exp (-0.1(V+34))-1)` and
:math:`\beta_{n}(V)=0.125\exp (-(V+44) / 80)` ; :math:`g_{\mathrm{K}}=9 \mathrm{mS} / \mathrm{cm}^{2}`, and
:math:`E_{\mathrm{K}}=-90 \mathrm{mV}`.


Parameters
----------
size: sequence of int, int
The size of the neuron group.
ENa: float, JaxArray, ndarray, Initializer, callable
The reversal potential of sodium. Default is 55 mV.
gNa: float, JaxArray, ndarray, Initializer, callable
The maximum conductance of the sodium channel. Default is 35 msiemens.
EK: float, JaxArray, ndarray, Initializer, callable
The reversal potential of potassium. Default is -90 mV.
gK: float, JaxArray, ndarray, Initializer, callable
The maximum conductance of the potassium channel. Default is 9 msiemens.
EL: float, JaxArray, ndarray, Initializer, callable
The reversal potential of the leaky channel. Default is -65 mV.
gL: float, JaxArray, ndarray, Initializer, callable
The conductance of the leaky channel. Default is 0.1 msiemens.
V_th: float, JaxArray, ndarray, Initializer, callable
The threshold of the membrane spike. Default is 20 mV.
C: float, JaxArray, ndarray, Initializer, callable
The membrane capacitance. Default is 1 ufarad.
phi: float, JaxArray, ndarray, Initializer, callable
The temperature-dependent factor :math:`\phi` that scales the gating kinetics. Default is 5.
V_initializer: JaxArray, ndarray, Initializer, callable
The initializer of the membrane potential.
h_initializer: JaxArray, ndarray, Initializer, callable
The initializer of the h gating variable.
n_initializer: JaxArray, ndarray, Initializer, callable
The initializer of the n gating variable.
method: str
The numerical integration method.
name: str
The group name.

References
----------
.. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic
inhibition in a hippocampal interneuronal network model. Journal of
neuroscience, 16(20), pp.6402-6413.

"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
ENa: Union[float, Array, Initializer, Callable] = 55.,
gNa: Union[float, Array, Initializer, Callable] = 35.,
EK: Union[float, Array, Initializer, Callable] = -90.,
gK: Union[float, Array, Initializer, Callable] = 9.,
EL: Union[float, Array, Initializer, Callable] = -65,
gL: Union[float, Array, Initializer, Callable] = 0.1,
V_th: Union[float, Array, Initializer, Callable] = 20.,
phi: Union[float, Array, Initializer, Callable] = 5.0,
C: Union[float, Array, Initializer, Callable] = 1.0,
V_initializer: Union[Initializer, Callable, Array] = OneInit(-65.),
h_initializer: Union[Initializer, Callable, Array] = OneInit(0.6),
n_initializer: Union[Initializer, Callable, Array] = OneInit(0.32),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
):
# initialization
super(WangBuzsakiModel, self).__init__(size=size, keep_size=keep_size, name=name, mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__)

# parameters
self.ENa = parameter(ENa, self.varshape, allow_none=False)
self.EK = parameter(EK, self.varshape, allow_none=False)
self.EL = parameter(EL, self.varshape, allow_none=False)
self.gNa = parameter(gNa, self.varshape, allow_none=False)
self.gK = parameter(gK, self.varshape, allow_none=False)
self.gL = parameter(gL, self.varshape, allow_none=False)
self.C = parameter(C, self.varshape, allow_none=False)
self.phi = parameter(phi, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=3)

# initializers
check_initializer(h_initializer, 'h_initializer', allow_none=False)
check_initializer(n_initializer, 'n_initializer', allow_none=False)
check_initializer(V_initializer, 'V_initializer', allow_none=False)
self._h_initializer = h_initializer
self._n_initializer = n_initializer
self._V_initializer = V_initializer

# variables
self.h = variable(self._h_initializer, mode, self.varshape)
self.n = variable(self._n_initializer, mode, self.varshape)
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.h.value = variable(self._h_initializer, batch_size, self.varshape)
self.n.value = variable(self._n_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def m_inf(self, V):
alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
beta = 4. * bm.exp(-(V + 60.) / 18.)
return alpha / (alpha + beta)

def dh(self, h, t, V):
alpha = 0.07 * bm.exp(-(V + 58) / 20)
beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1)
dhdt = alpha * (1 - h) - beta * h
return self.phi * dhdt

def dn(self, n, t, V):
alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
beta = 0.125 * bm.exp(-(V + 44) / 80)
dndt = alpha * (1 - n) - beta * n
return self.phi * dndt

def dV(self, V, t, h, n, I_ext):
INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa)
IK = self.gK * n ** 4 * (V - self.EK)
IL = self.gL * (V - self.EL)
dVdt = (- INa - IK - IL + I_ext) / self.C
return dVdt

@property
def derivative(self):
return JointEq([self.dV, self.dh, self.dn])

def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']
if x is not None: self.input += x
V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt)
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.h.value = h
self.n.value = n

def clear_input(self):
self.input[:] = 0.
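
Putting the changes in this file together: the constructors now accept `keep_size`, `noise` and `mode`, `reset()` has become `reset_state(batch_size=None)`, and `update()` takes a shared `tdi` dict plus an optional external input, with inputs cleared separately via `clear_input()`. A minimal sketch of driving the newly added `WangBuzsakiModel` under these conventions (it assumes the class is re-exported as `bp.dyn.WangBuzsakiModel`, as the neuron package's `__init__.py` suggests):

import brainpy as bp
import brainpy.math as bm

group = bp.dyn.WangBuzsakiModel(10)              # 10 neurons, default parameters
group.reset_state()

dt = bm.get_dt()
for i in range(1000):
    group.update({'t': i * dt, 'dt': dt}, x=1.)  # constant external current
    group.clear_input()                          # inputs are now cleared explicitly
print(group.spike)                               # boolean spike flags of the last step

In practice the group would usually be wrapped in a runner rather than stepped by hand; the manual loop above only mirrors the `update`/`clear_input` split introduced in this PR.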

+ 16
- 0
brainpy/dyn/neurons/compat.py View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-


from .biological_models import HH, MorrisLecar, PinskyRinzelModel
from .fractional_models import FractionalFHR, FractionalIzhikevich
from .reduced_models import LIF, ExpIF, AdExIF, QuaIF, AdQuaIF, GIF, Izhikevich, HindmarshRose, FHN
from .input_groups import SpikeTimeGroup, PoissonGroup
from .noise_groups import OUProcess

__all__ = [
'HH', 'MorrisLecar', 'PinskyRinzelModel',
'FractionalFHR', 'FractionalIzhikevich',
'LIF', 'ExpIF', 'AdExIF', 'QuaIF', 'AdQuaIF',
'GIF', 'Izhikevich', 'HindmarshRose', 'FHN',
'SpikeTimeGroup', 'PoissonGroup', 'OUProcess'
]

+ 76
- 65
brainpy/dyn/neurons/fractional_models.py View File

@@ -4,13 +4,13 @@ from typing import Union, Sequence, Callable

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import ZeroInit, OneInit, Initializer, init_param
from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy.integrators.fde import CaputoL1Schema
from brainpy.integrators.fde import GLShortMemory
from brainpy.integrators.joint_eq import JointEq
from brainpy.tools.checking import check_float, check_integer
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Tensor
from brainpy.types import Shape, Array

__all__ = [
'FractionalNeuron',
@@ -83,32 +83,33 @@ class FractionalFHR(FractionalNeuron):
size: Shape,
alpha: Union[float, Sequence[float]],
num_memory: int = 1000,
a: Union[float, Tensor, Initializer, Callable] = 0.7,
b: Union[float, Tensor, Initializer, Callable] = 0.8,
c: Union[float, Tensor, Initializer, Callable] = -0.775,
d: Union[float, Tensor, Initializer, Callable] = 1.,
delta: Union[float, Tensor, Initializer, Callable] = 0.08,
mu: Union[float, Tensor, Initializer, Callable] = 0.0001,
Vth: Union[float, Tensor, Initializer, Callable] = 1.8,
V_initializer: Union[Initializer, Callable, Tensor] = OneInit(2.5),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
y_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
name: str = None
a: Union[float, Array, Initializer, Callable] = 0.7,
b: Union[float, Array, Initializer, Callable] = 0.8,
c: Union[float, Array, Initializer, Callable] = -0.775,
d: Union[float, Array, Initializer, Callable] = 1.,
delta: Union[float, Array, Initializer, Callable] = 0.08,
mu: Union[float, Array, Initializer, Callable] = 0.0001,
Vth: Union[float, Array, Initializer, Callable] = 1.8,
V_initializer: Union[Initializer, Callable, Array] = OneInit(2.5),
w_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
y_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
name: str = None,
keep_size: bool = False,
):
super(FractionalFHR, self).__init__(size, name=name)
super(FractionalFHR, self).__init__(size, keep_size=keep_size, name=name)

# fractional order
self.alpha = alpha
check_integer(num_memory, 'num_memory', allow_none=False)

# parameters
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.c = init_param(c, self.num, allow_none=False)
self.d = init_param(d, self.num, allow_none=False)
self.mu = init_param(mu, self.num, allow_none=False)
self.Vth = init_param(Vth, self.num, allow_none=False)
self.delta = init_param(delta, self.num, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.c = parameter(c, self.varshape, allow_none=False)
self.d = parameter(d, self.varshape, allow_none=False)
self.mu = parameter(mu, self.varshape, allow_none=False)
self.Vth = parameter(Vth, self.varshape, allow_none=False)
self.delta = parameter(delta, self.varshape, allow_none=False)

# initializers
check_initializer(V_initializer, 'V_initializer', allow_none=False)
@@ -119,11 +120,11 @@ class FractionalFHR(FractionalNeuron):
self._y_initializer = y_initializer

# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = bm.Variable(parameter(V_initializer, self.varshape))
self.w = bm.Variable(parameter(w_initializer, self.varshape))
self.y = bm.Variable(parameter(y_initializer, self.varshape))
self.input = bm.Variable(bm.zeros(self.varshape))
self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool))

# integral function
self.integral = GLShortMemory(self.derivative,
@@ -131,10 +132,11 @@ class FractionalFHR(FractionalNeuron):
num_memory=num_memory,
inits=[self.V, self.w, self.y])

def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.w.value = init_param(self._w_initializer, (self.num,))
self.y.value = init_param(self._y_initializer, (self.num,))
def reset_state(self, batch_size=None):
assert batch_size is None
self.V.value = parameter(self._V_initializer, self.varshape)
self.w.value = parameter(self._w_initializer, self.varshape)
self.y.value = parameter(self._y_initializer, self.varshape)
self.input[:] = 0
self.spike[:] = False
# integral function reset
@@ -153,12 +155,16 @@ class FractionalFHR(FractionalNeuron):
def derivative(self):
return JointEq([self.dV, self.dw, self.dy])

def update(self, t, dt):
def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']
if x is not None: self.input += x
V, w, y = self.integral(self.V, self.w, self.y, t, dt)
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth)
self.V.value = V
self.w.value = w
self.y.value = y

def clear_input(self):
self.input[:] = 0.


@@ -220,37 +226,38 @@ class FractionalIzhikevich(FractionalNeuron):
self,
size: Shape,
alpha: Union[float, Sequence[float]],
num_step: int,
a: Union[float, Tensor, Initializer, Callable] = 0.02,
b: Union[float, Tensor, Initializer, Callable] = 0.20,
c: Union[float, Tensor, Initializer, Callable] = -65.,
d: Union[float, Tensor, Initializer, Callable] = 8.,
f: Union[float, Tensor, Initializer, Callable] = 0.04,
g: Union[float, Tensor, Initializer, Callable] = 5.,
h: Union[float, Tensor, Initializer, Callable] = 140.,
R: Union[float, Tensor, Initializer, Callable] = 1.,
tau: Union[float, Tensor, Initializer, Callable] = 1.,
V_th: Union[float, Tensor, Initializer, Callable] = 30.,
V_initializer: Union[Initializer, Callable, Tensor] = OneInit(-65.),
u_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.20 * -65.),
num_memory: int,
a: Union[float, Array, Initializer, Callable] = 0.02,
b: Union[float, Array, Initializer, Callable] = 0.20,
c: Union[float, Array, Initializer, Callable] = -65.,
d: Union[float, Array, Initializer, Callable] = 8.,
f: Union[float, Array, Initializer, Callable] = 0.04,
g: Union[float, Array, Initializer, Callable] = 5.,
h: Union[float, Array, Initializer, Callable] = 140.,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 1.,
V_th: Union[float, Array, Initializer, Callable] = 30.,
V_initializer: Union[Initializer, Callable, Array] = OneInit(-65.),
u_initializer: Union[Initializer, Callable, Array] = OneInit(0.20 * -65.),
keep_size: bool = False,
name: str = None
):
# initialization
super(FractionalIzhikevich, self).__init__(size=size, name=name)
super(FractionalIzhikevich, self).__init__(size=size, keep_size=keep_size, name=name)

# params
self.alpha = alpha
check_float(alpha, 'alpha', min_bound=0., max_bound=1., allow_none=False, allow_int=True)
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.c = init_param(c, self.num, allow_none=False)
self.d = init_param(d, self.num, allow_none=False)
self.f = init_param(f, self.num, allow_none=False)
self.g = init_param(g, self.num, allow_none=False)
self.h = init_param(h, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.R = init_param(R, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.c = parameter(c, self.varshape, allow_none=False)
self.d = parameter(d, self.varshape, allow_none=False)
self.f = parameter(f, self.varshape, allow_none=False)
self.g = parameter(g, self.varshape, allow_none=False)
self.h = parameter(h, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)

# initializers
check_initializer(V_initializer, 'V_initializer', allow_none=False)
@@ -259,21 +266,21 @@ class FractionalIzhikevich(FractionalNeuron):
self._u_initializer = u_initializer

# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.u = bm.Variable(init_param(u_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = bm.Variable(parameter(V_initializer, self.varshape))
self.u = bm.Variable(parameter(u_initializer, self.varshape))
self.input = bm.Variable(bm.zeros(self.varshape))
self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool))

# functions
check_integer(num_step, 'num_step', allow_none=False)
check_integer(num_memory, 'num_memory', allow_none=False)
self.integral = CaputoL1Schema(f=self.derivative,
alpha=alpha,
num_step=num_step,
num_memory=num_memory,
inits=[self.V, self.u])

def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.u.value = init_param(self._u_initializer, (self.num,))
def reset_state(self, batch_size=None):
self.V.value = parameter(self._V_initializer, self.varshape)
self.u.value = parameter(self._u_initializer, self.varshape)
self.input[:] = 0
self.spike[:] = False
# integral function reset
@@ -291,10 +298,14 @@ class FractionalIzhikevich(FractionalNeuron):
def derivative(self):
return JointEq([self.dV, self.du])

def update(self, t, dt):
def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']
if x is not None: self.input += x
V, u = self.integral(self.V, self.u, t=t, I_ext=self.input, dt=dt)
spikes = V >= self.V_th
self.V.value = bm.where(spikes, self.c, V)
self.u.value = bm.where(spikes, u + self.d, u)
self.spike.value = spikes

def clear_input(self):
self.input[:] = 0.
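
The constructor change above renames `num_step` to `num_memory` and threads it into `CaputoL1Schema`, while the update method follows the same `tdi` convention as the other neuron groups. A small sketch of the new call pattern (assuming the class is re-exported as `bp.dyn.FractionalIzhikevich`):

import brainpy as bp
import brainpy.math as bm

neu = bp.dyn.FractionalIzhikevich(size=5, alpha=0.9, num_memory=1000)
dt = bm.get_dt()
neu.update({'t': 0., 'dt': dt}, x=10.)   # tdi dict plus optional external input
neu.clear_input()
print(neu.V.shape)                        # (5,)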

+ 207
- 0
brainpy/dyn/neurons/input_groups.py View File

@@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-

from typing import Union, Sequence

import jax.numpy as jnp

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.errors import ModelBuildError
from brainpy.initialize import Initializer, parameter, variable
from brainpy.modes import Mode, BatchingMode, normal
from brainpy.types import Shape, Array

__all__ = [
'InputGroup',
'OutputGroup',
'SpikeTimeGroup',
'PoissonGroup',
]


class InputGroup(NeuGroup):
"""Input neuron group for place holder.

Parameters
----------
size: int, tuple of int
keep_size: bool
mode: Mode
name: str
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
mode: Mode = normal,
name: str = None,
):
super(InputGroup, self).__init__(name=name,
size=size,
keep_size=keep_size,
mode=mode)
self.spike = None

def update(self, tdi, x=None):
pass

def reset_state(self, batch_size=None):
pass


class OutputGroup(NeuGroup):
"""Output neuron group for place holder.

Parameters
----------
size: int, tuple of int
keep_size: bool
mode: Mode
name: str
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
mode: Mode = normal,
name: str = None,
):
super(OutputGroup, self).__init__(name=name,
size=size,
keep_size=keep_size,
mode=mode)
self.spike = None

def update(self, tdi, x=None):
pass

def reset_state(self, batch_size=None):
pass


class SpikeTimeGroup(NeuGroup):
"""The input neuron group characterized by spikes emitting at given times.

>>> # Get 2 neurons, firing spikes at 10 ms and 20 ms.
>>> SpikeTimeGroup(2, times=[10, 20])
>>> # or
>>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
>>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 0])
>>> # or
>>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
>>> SpikeTimeGroup(2, times=[10, 20, 30], indices=[0, 1, 0])
>>> # or
>>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
>>> # at 30 ms, neuron 1 fires.
>>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

Parameters
----------
size : int, tuple, list
The neuron group geometry.
indices : list, tuple, np.ndarray, JaxArray, jax.numpy.ndarray
The neuron indices at each time point to emit spikes.
times : list, tuple, np.ndarray, JaxArray, jax.numpy.ndarray
The time points which generate the spikes.
name : str, optional
The name of the dynamic system.
"""

def __init__(
self,
size: Shape,
times: Union[Sequence, Array],
indices: Union[Sequence, Array],
need_sort: bool = True,
keep_size: bool = False,
mode: Mode = normal,
name: str = None
):
super(SpikeTimeGroup, self).__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)

# parameters
if keep_size:
raise NotImplementedError(f'keep_size=True is not supported in {self.__class__.__name__}')
if len(indices) != len(times):
raise ModelBuildError(f'The length of "indices" and "times" must be the same. '
f'However, we got {len(indices)} != {len(times)}.')
self.num_times = len(times)

# data about times and indices
self.times = bm.asarray(times)
self.indices = bm.asarray(indices, dtype=bm.ditype())

# variables
self.i = bm.Variable(bm.zeros(1, dtype=bm.ditype()))
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
if need_sort:
sort_idx = bm.argsort(self.times)
self.indices.value = self.indices[sort_idx]
self.times.value = self.times[sort_idx]

# functions
def cond_fun(t):
i = self.i[0]
return bm.logical_and(i < self.num_times, t >= self.times[i])

def body_fun(t):
i = self.i[0]
if isinstance(self.mode, BatchingMode):
self.spike[:, self.indices[i]] = True
else:
self.spike[self.indices[i]] = True
self.i += 1

self._run = bm.make_while(cond_fun, body_fun, dyn_vars=self.vars())

def reset_state(self, batch_size=None):
self.i[0] = 1
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def update(self, tdi, x=None):
self.spike[:] = False
self._run(tdi['t'])


class PoissonGroup(NeuGroup):
"""Poisson Neuron Group.
"""

def __init__(
self,
size: Shape,
freqs: Union[int, float, jnp.ndarray, bm.JaxArray, Initializer],
seed: int = None,
keep_size: bool = False,
mode: Mode = normal,
name: str = None
):
super(PoissonGroup, self).__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)

# parameters
self.keep_size = keep_size
self.seed = seed
self.freqs = parameter(freqs, self.num, allow_none=False)

# variables
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.rng = bm.random.RandomState(seed=seed)

def update(self, tdi, x=None):
shape = (self.spike.shape[:1] + self.varshape) if isinstance(self.mode, BatchingMode) else self.varshape
self.spike.update(self.rng.random(shape) <= (self.freqs * tdi['dt'] / 1000.))

def reset(self, batch_size=None):
self.rng.seed(self.seed)
self.reset_state(batch_size)

def reset_state(self, batch_size=None):
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
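
These replacement groups keep the same construction API as the old `input_models.py` versions (deleted below) but adopt the shared-argument `update(tdi, x=None)` signature and per-batch `reset_state`. A minimal sketch of using both groups (assuming they are re-exported under `bp.dyn`):

import brainpy as bp
import brainpy.math as bm

spikes = bp.dyn.SpikeTimeGroup(2, times=[10., 20., 20., 30.], indices=[0, 0, 1, 1])
poisson = bp.dyn.PoissonGroup(100, freqs=20.)     # ~20 Hz Poisson firing per neuron

dt = bm.get_dt()
for i in range(100):
    tdi = {'t': i * dt, 'dt': dt}
    spikes.update(tdi)
    poisson.update(tdi)
print(spikes.spike.value, int(poisson.spike.sum()))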

+ 0
- 159
brainpy/dyn/neurons/input_models.py View File

@@ -1,159 +0,0 @@
# -*- coding: utf-8 -*-

import warnings
from typing import Union

import jax.numpy as jnp

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.errors import ModelBuildError
from brainpy.initialize import Initializer, init_param
from brainpy.types import Shape

__all__ = [
'SpikeTimeInput',
'PoissonInput',
'SpikeTimeGroup',
'PoissonGroup',
]


class SpikeTimeGroup(NeuGroup):
"""The input neuron group characterized by spikes emitting at given times.

>>> # Get 2 neurons, firing spikes at 10 ms and 20 ms.
>>> SpikeTimeGroup(2, times=[10, 20])
>>> # or
>>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
>>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 0])
>>> # or
>>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
>>> SpikeTimeGroup(2, times=[10, 20, 30], indices=[0, 1, 0])
>>> # or
>>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
>>> # at 30 ms, neuron 1 fires.
>>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

Parameters
----------
size : int, tuple, list
The neuron group geometry.
indices : list, tuple, np.ndarray, JaxArray, jax.numpy.ndarray
The neuron indices at each time point to emit spikes.
times : list, tuple, np.ndarray, JaxArray, jax.numpy.ndarray
The time points which generate the spikes.
name : str, optional
The name of the dynamic system.
"""

def __init__(
self,
size: Shape,
times,
indices,
need_sort: bool = True,
name: str = None
):
super(SpikeTimeGroup, self).__init__(size=size, name=name)

# parameters
if len(indices) != len(times):
raise ModelBuildError(f'The length of "indices" and "times" must be the same. '
f'However, we got {len(indices)} != {len(times)}.')
self.num_times = len(times)

# data about times and indices
self.times = bm.asarray(times, dtype=bm.float_)
self.indices = bm.asarray(indices, dtype=bm.int_)

# variables
self.i = bm.Variable(bm.zeros(1, dtype=bm.int_))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
if need_sort:
sort_idx = bm.argsort(self.times)
self.indices.value = self.indices[sort_idx]
self.times.value = self.times[sort_idx]

# functions
def cond_fun(t):
return bm.logical_and(self.i[0] < self.num_times, t >= self.times[self.i[0]])

def body_fun(t):
self.spike[self.indices[self.i[0]]] = True
self.i[0] += 1

self._run = bm.make_while(cond_fun, body_fun, dyn_vars=self.vars())

def reset(self):
self.i[0] = 1
self.spike[:] = False

def update(self, t, _i, **kwargs):
self.spike[:] = False
self._run(t)


def SpikeTimeInput(*args, **kwargs):
"""Spike Time Input.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.SpikeTimeGroup" instead.

Returns
-------
group: NeuGroup
The neural group.
"""
warnings.warn('Please use "brainpy.dyn.SpikeTimeGroup" instead. '
'"brainpy.dyn.SpikeTimeInput" is deprecated since '
'version 2.1.5', DeprecationWarning)
return SpikeTimeGroup(*args, **kwargs)


class PoissonGroup(NeuGroup):
"""Poisson Neuron Group.
"""

def __init__(
self,
size: Shape,
freqs: Union[float, jnp.ndarray, bm.JaxArray, Initializer],
seed: int = None,
name: str = None
):
super(PoissonGroup, self).__init__(size=size, name=name)

# parameters
self.seed = seed
self.freqs = init_param(freqs, self.num, allow_none=False)
self.dt = bm.get_dt() / 1000.
self.size = (size,) if isinstance(size, int) else tuple(size)

# variables
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.rng = bm.random.RandomState(seed=seed)

def update(self, t, _i):
self.spike.update(self.rng.random(self.num) <= self.freqs * self.dt)

def reset(self):
self.spike[:] = False
self.rng.seed(self.seed)


def PoissonInput(*args, **kwargs):
"""Poisson Group Input.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.PoissonGroup" instead.

Returns
-------
poisson_group: NeuGroup
The poisson neural group.
"""
warnings.warn('Please use "brainpy.dyn.PoissonGroup" instead. '
'"brainpy.dyn.PoissonInput" is deprecated since '
'version 2.1.5', DeprecationWarning)
return PoissonGroup(*args, **kwargs)

brainpy/dyn/rates/noises.py → brainpy/dyn/neurons/noise_groups.py View File

@@ -2,11 +2,12 @@

from typing import Union, Callable

import brainpy.math as bm
from brainpy import math as bm, initialize as init
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import init_param, Initializer
from brainpy.initialize import Initializer
from brainpy.integrators.sde import sdeint
from brainpy.types import Tensor, Shape
from brainpy.modes import Mode, normal
from brainpy.types import Array, Shape

__all__ = [
'OUProcess',
@@ -45,34 +46,36 @@ class OUProcess(NeuGroup):
def __init__(
self,
size: Shape,
mean: Union[float, Tensor, Initializer, Callable] = 0.,
sigma: Union[float, Tensor, Initializer, Callable] = 1.,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
method: str = 'euler',
name: str = None
mean: Union[float, Array, Initializer, Callable] = 0.,
sigma: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
method: str = 'exp_euler',
keep_size: bool = False,
mode: Mode = normal,
name: str = None,
):
super(OUProcess, self).__init__(size=size, name=name)
super(OUProcess, self).__init__(size=size, name=name, keep_size=keep_size, mode=mode)


# parameters
self.mean = init_param(mean, self.num, allow_none=False)
self.sigma = init_param(sigma, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.mean = init.parameter(mean, self.varshape, allow_none=False)
self.sigma = init.parameter(sigma, self.varshape, allow_none=False)
self.tau = init.parameter(tau, self.varshape, allow_none=False)

# variables
self.x = bm.Variable(bm.ones(self.num) * mean)
self.x = init.variable(lambda s: bm.ones(s) * self.mean, mode, self.varshape)

# integral functions
self.integral = sdeint(f=self.df, g=self.dg, method=method)

def reset(self):
self.x[:] = self.mean
def reset_state(self, batch_size=None):
self.x.value = init.variable(lambda s: bm.ones(s) * self.mean, batch_size, self.varshape)

def df(self, x, t):
f_x_ou = (self.mean - x) / self.tau
return f_x_ou
return (self.mean - x) / self.tau

def dg(self, x, t):
return self.sigma

def update(self, t, dt):
self.x.value = self.integral(self.x, t, dt)
def update(self, tdi):
self.x.value = self.integral(self.x, tdi['t'], tdi['dt'])
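
Here `df` and `dg` are the drift and diffusion of an Ornstein-Uhlenbeck process and `sdeint` integrates them. For intuition, a hand-rolled Euler-Maruyama step for the same SDE, as a plain NumPy sketch (illustrative only, not the integrator used by the class):

import numpy as np

def ou_step(x, mean, sigma, tau, dt, rng):
    # dx = (mean - x) / tau * dt + sigma * dW, with dW ~ Normal(0, dt)
    return x + (mean - x) / tau * dt + sigma * np.sqrt(dt) * rng.standard_normal(np.shape(x))

rng = np.random.default_rng(0)
x = np.zeros(10)
for _ in range(1000):
    x = ou_step(x, mean=0., sigma=1., tau=10., dt=0.1, rng=rng)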

+ 897
- 362
brainpy/dyn/neurons/reduced_models.py View File

@@ -2,26 +2,128 @@

from typing import Union, Callable

from jax.lax import stop_gradient

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import ZeroInit, OneInit, Initializer, init_param
from brainpy.initialize import (ZeroInit, OneInit, Initializer,
parameter, variable, noise as init_noise)
from brainpy.integrators import sdeint, odeint, JointEq
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Tensor
from brainpy.modes import Mode, NormalMode, BatchingMode, TrainingMode, normal, check
from brainpy.tools.checking import check_initializer, check_callable
from brainpy.types import Shape, Array

__all__ = [
'LeakyIntegrator',
'LIF',
'ExpIF',
'AdExIF',
'QuaIF',
'AdQuaIF',
'GIF',
'ALIFBellec2020',
'Izhikevich',
'HindmarshRose',
'FHN',
]


class LeakyIntegrator(NeuGroup):
r"""Leaky Integrator Model.
**Model Descriptions**
This class implements a leaky integrator model, in which its dynamics is
given by:
.. math::
\tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t)

where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting
membrane potential, :math:`\tau` is the time constant, and :math:`R` is the
resistance.

Parameters
----------
size: sequence of int, int
The size of the neuron group.
V_rest: float, JaxArray, ndarray, Initializer, callable
Resting membrane potential.
R: float, JaxArray, ndarray, Initializer, callable
Membrane resistance.
tau: float, JaxArray, ndarray, Initializer, callable
Membrane time constant.
V_initializer: JaxArray, ndarray, Initializer, callable
The initializer of membrane potential.
noise: JaxArray, ndarray, Initializer, callable
The noise added onto the membrane potential.
method: str
The numerical integration method.
name: str
The group name.
"""

def __init__(
self,

# neuron group size
size: Shape,
keep_size: bool = False,

# neuron parameters
V_rest: Union[float, Array, Initializer, Callable] = 0.,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,

# other parameter
name: str = None,
mode: Mode = normal,
method: str = 'exp_auto',
):
super(LeakyIntegrator, self).__init__(size=size,
mode=mode,
keep_size=keep_size,
name=name)
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape)

# initializers
check_initializer(V_initializer, 'V_initializer')
self._V_initializer = V_initializer

# variables
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)

# integral
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def derivative(self, V, t, I_ext):
return (-V + self.V_rest + self.R * I_ext) / self.tau

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)

def update(self, tdi, x=None):
if x is not None: self.input += x
self.V.value = self.integral(self.V.value, tdi.t, self.input.value, tdi.dt)

def clear_input(self):
self.input[:] = 0.
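
With the default `method='exp_auto'`, this linear equation admits an exact per-step update when the input is held constant over `dt`: V relaxes toward V_inf = V_rest + R * I with time constant tau. A standalone sketch of that step (assumes constant input within the step; names illustrative):

import numpy as np

def leaky_step(V, I, V_rest=0.0, R=1.0, tau=10.0, dt=0.1):
    # exact solution of tau * dV/dt = -(V - V_rest) + R * I over one step of length dt
    V_inf = V_rest + R * I
    return V_inf + (V - V_inf) * np.exp(-dt / tau)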


class LIF(NeuGroup):
r"""Leaky integrate-and-fire neuron model.

@@ -31,7 +133,7 @@ class LIF(NeuGroup):

.. math::

\tau \frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \\
\tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\
\text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad
\text{last} \quad \tau_{ref} \quad \text{ms}

@@ -56,6 +158,8 @@ class LIF(NeuGroup):
Reset potential after spike.
V_th: float, JaxArray, ndarray, Initializer, callable
Threshold potential of spike.
R: float, JaxArray, ndarray, Initializer, callable
Membrane resistance.
tau: float, JaxArray, ndarray, Initializer, callable
Membrane time constant.
tau_ref: float, JaxArray, ndarray, Initializer, callable
@@ -64,8 +168,6 @@ class LIF(NeuGroup):
The initializer of membrane potential.
noise: JaxArray, ndarray, Initializer, callable
The noise added onto the membrane potential.
noise_type: str
The type of the provided noise. Can be `value` or `func`.
method: str
The numerical integration method.
name: str
@@ -81,72 +183,118 @@ class LIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = 0.,
V_reset: Union[float, Tensor, Initializer, Callable] = -5.,
V_th: Union[float, Tensor, Initializer, Callable] = 20.,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
tau_ref: Union[float, Tensor, Initializer, Callable] = 1.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
noise: Union[float, Tensor, Initializer, Callable] = None,
noise_type: str = 'value',
keep_size: bool=False,
keep_size: bool = False,

# other parameter
V_rest: Union[float, Array, Initializer, Callable] = 0.,
V_reset: Union[float, Array, Initializer, Callable] = -5.,
V_th: Union[float, Array, Initializer, Callable] = 20.,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_ref: Union[float, Array, Initializer, Callable] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
name: str = None,

# training parameter
mode: Mode = normal,
spike_fun: Callable = bm.spike_with_sigmoid_grad,
):
# initialization
super(LIF, self).__init__(size=size, name=name)
super(LIF, self).__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.keep_size = keep_size
self.noise_type = noise_type
if noise_type not in ['func', 'value']:
raise ValueError(f'noise_type only supports `func` and `value`, but we got {noise_type}')
size = self.size if keep_size else self.num
self.V_rest = init_param(V_rest, size, allow_none=False)
self.V_reset = init_param(V_reset, size, allow_none=False)
self.V_th = init_param(V_th, size, allow_none=False)
self.tau = init_param(tau, size, allow_none=False)
self.tau_ref = init_param(tau_ref, size, allow_none=False)
if noise_type == 'func':
self.noise = noise
else:
self.noise = init_param(noise, size, allow_none=True)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
self.noise = init_noise(noise, self.varshape)
self.spike_fun = check_callable(spike_fun, 'spike_fun')

# initializers
check_initializer(V_initializer, 'V_initializer')
self._V_initializer = V_initializer

# variables
self.V = bm.Variable(init_param(V_initializer, size))
self.input = bm.Variable(bm.zeros(size))
self.spike = bm.Variable(bm.zeros(size, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(size) * -1e7)
self.refractory = bm.Variable(bm.zeros(size, dtype=bool))
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(mode, TrainingMode) else bool # the gradient of spike is a float
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
if self.tau_ref is not None:
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
f = lambda V, t, I_ext: (-V + self.V_rest + I_ext) / self.tau
if self.noise is not None:
g = noise if (noise_type == 'func') else (lambda V, t, I_ext: self.noise / bm.sqrt(self.tau))
self.integral = sdeint(method=method, f=f, g=g)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = odeint(method=method, f=f)

def reset(self):
self.V.value = init_param(self._V_initializer, self.size if self.keep_size else self.num)
self.input[:] = 0
self.spike[:] = False
self.t_last_spike[:] = -1e7
self.refractory[:] = False

def update(self, t, dt):
refractory = (t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, t, self.input, dt=dt)
V = bm.where(refractory, self.V, V)
spike = V >= self.V_th
self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
self.spike.value = spike
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def derivative(self, V, t, I_ext):
return (-V + self.V_rest + self.R * I_ext) / self.tau

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
if self.tau_ref is not None:
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x

# integrate membrane potential
V = self.integral(self.V.value, t, self.input.value, dt)

if self.tau_ref is not None:
# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
if isinstance(self.mode, TrainingMode):
refractory = stop_gradient(refractory)
V = bm.where(refractory, self.V, V)

# spike, refractory, spiking time, and membrane potential reset
if isinstance(self.mode, TrainingMode):
spike = self.spike_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike)
V += (self.V_reset - V) * spike_no_grad
spike_ = spike_no_grad > 0.
# will be used in other places (e.g., the Delta synapse), so stop its gradient
refractory = stop_gradient(bm.logical_or(refractory, spike_).value)
t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike).value)
else:
spike = V >= self.V_th
V = bm.where(spike, self.V_reset, V)
refractory = bm.logical_or(refractory, spike)
t_last_spike = bm.where(spike, t, self.t_last_spike)
self.V.value = V
self.spike.value = spike
self.refractory.value = refractory
self.t_last_spike.value = t_last_spike

else:
# spike, spiking time, and membrane potential reset
if isinstance(self.mode, TrainingMode):
spike = self.spike_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike)
V += (self.V_reset - V) * spike_no_grad
else:
spike = V >= self.V_th
V = bm.where(spike, self.V_reset, V)
self.V.value = V
self.spike.value = spike

def clear_input(self):
self.input[:] = 0.
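
In `TrainingMode` the hard threshold is replaced by `spike_fun` (defaulting to `bm.spike_with_sigmoid_grad`): the forward pass still emits a spike, while the backward pass uses a smooth surrogate so gradients can propagate through the threshold. A generic JAX sketch of that idea (the steepness factor is arbitrary and this is not the actual BrainPy implementation):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def spike_surrogate(v_minus_th):
    # forward: hard Heaviside spike
    return (v_minus_th >= 0.).astype(jnp.float32)

def _fwd(v_minus_th):
    return spike_surrogate(v_minus_th), v_minus_th

def _bwd(v_minus_th, g):
    # backward: derivative of a sigmoid as a smooth stand-in for the Heaviside
    s = jax.nn.sigmoid(4.0 * v_minus_th)
    return (g * 4.0 * s * (1.0 - s),)

spike_surrogate.defvjp(_fwd, _bwd)

grads = jax.grad(lambda v: spike_surrogate(v).sum())(jnp.array([-0.5, 0.1, 1.0]))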


@@ -251,66 +399,94 @@ class ExpIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = -65.,
V_reset: Union[float, Tensor, Initializer, Callable] = -68.,
V_th: Union[float, Tensor, Initializer, Callable] = -30.,
V_T: Union[float, Tensor, Initializer, Callable] = -59.9,
delta_T: Union[float, Tensor, Initializer, Callable] = 3.48,
R: Union[float, Tensor, Initializer, Callable] = 1.,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
tau_ref: Union[float, Tensor, Initializer, Callable] = 1.7,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
V_rest: Union[float, Array, Initializer, Callable] = -65.,
V_reset: Union[float, Array, Initializer, Callable] = -68.,
V_th: Union[float, Array, Initializer, Callable] = -30.,
V_T: Union[float, Array, Initializer, Callable] = -59.9,
delta_T: Union[float, Array, Initializer, Callable] = 3.48,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_ref: Union[float, Array, Initializer, Callable] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
keep_size: bool = False,
mode: Mode = normal,
method: str = 'exp_auto',
name: str = None
):
# initialize
super(ExpIF, self).__init__(size=size, name=name)
super(ExpIF, self).__init__(size=size,
name=name,
mode=mode,
keep_size=keep_size, )
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.V_reset = init_param(V_reset, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_T = init_param(V_T, self.num, allow_none=False)
self.delta_T = init_param(delta_T, self.num, allow_none=False)
self.tau_ref = init_param(tau_ref, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.R = init_param(R, self.num, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_T = parameter(V_T, self.varshape, allow_none=False)
self.delta_T = parameter(delta_T, self.varshape, allow_none=False)
self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape)

# initializers
check_initializer(V_initializer, 'V_initializer')
self._V_initializer = V_initializer

# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
if self.tau_ref is not None:
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
self.t_last_spike[:] = -1e7
self.refractory[:] = False
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
if self.tau_ref is not None:
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def derivative(self, V, t, I_ext):
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau
return dvdt

def update(self, t, dt):
refractory = (t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, t, self.input, dt=dt)
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
V = self.integral(self.V.value, t, self.input.value, dt)

if self.tau_ref is not None:
refractory = (t - self.t_last_spike) <= self.tau_ref
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
t_last_spike = bm.where(spike, t, self.t_last_spike)
V = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
else:
spike = self.V_th <= V
t_last_spike = bm.where(spike, t, self.t_last_spike)
V = bm.where(spike, self.V_reset, V)

self.V.value = V
self.spike.value = spike
self.t_last_spike.value = t_last_spike

def clear_input(self):
self.input[:] = 0.
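
The `derivative` above is the exponential integrate-and-fire right-hand side: the exponential term makes dV/dt diverge once V passes V_T, producing the spike upstroke. The same expression as a standalone NumPy sketch (defaults copied from the constructor above):

import numpy as np

def expif_dvdt(V, I_ext, V_rest=-65., V_T=-59.9, delta_T=3.48, R=1., tau=10.):
    exp_v = delta_T * np.exp((V - V_T) / delta_T)
    return (-(V - V_rest) + exp_v + R * I_ext) / tau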


@@ -390,34 +566,42 @@ class AdExIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = -65.,
V_reset: Union[float, Tensor, Initializer, Callable] = -68.,
V_th: Union[float, Tensor, Initializer, Callable] = -30.,
V_T: Union[float, Tensor, Initializer, Callable] = -59.9,
delta_T: Union[float, Tensor, Initializer, Callable] = 3.48,
a: Union[float, Tensor, Initializer, Callable] = 1.,
b: Union[float, Tensor, Initializer, Callable] = 1.,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
tau_w: Union[float, Tensor, Initializer, Callable] = 30.,
R: Union[float, Tensor, Initializer, Callable] = 1.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
V_rest: Union[float, Array, Initializer, Callable] = -65.,
V_reset: Union[float, Array, Initializer, Callable] = -68.,
V_th: Union[float, Array, Initializer, Callable] = -30.,
V_T: Union[float, Array, Initializer, Callable] = -59.9,
delta_T: Union[float, Array, Initializer, Callable] = 3.48,
a: Union[float, Array, Initializer, Callable] = 1.,
b: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_w: Union[float, Array, Initializer, Callable] = 30.,
R: Union[float, Array, Initializer, Callable] = 1.,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
keep_size: bool = False,
mode: Mode = normal,
name: str = None
):
super(AdExIF, self).__init__(size=size, name=name)
super(AdExIF, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode, )
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.V_reset = init_param(V_reset, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_T = init_param(V_T, self.num, allow_none=False)
self.delta_T = init_param(delta_T, self.num, allow_none=False)
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.tau_w = init_param(tau_w, self.num, allow_none=False)
self.R = init_param(R, self.num, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_T = parameter(V_T, self.varshape, allow_none=False)
self.delta_T = parameter(delta_T, self.varshape, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.tau_w = parameter(tau_w, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=2)

# initializers
check_initializer(V_initializer, 'V_initializer')
@@ -426,25 +610,28 @@ class AdExIF(NeuGroup):
self._w_initializer = w_initializer

# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = variable(V_initializer, mode, self.varshape)
self.w = variable(w_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(mode, BatchingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)

# functions
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.w.value = init_param(self._w_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
self.refractory[:] = False
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.w.value = variable(self._w_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)

def dV(self, V, t, w, I_ext):
dVdt = (- V + self.V_rest + self.delta_T * bm.exp((V - self.V_T) / self.delta_T) -
self.R * w + self.R * I_ext) / self.tau
exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I_ext) / self.tau
return dVdt

def dw(self, w, t, V):
@@ -455,12 +642,16 @@ class AdExIF(NeuGroup):
def derivative(self):
return JointEq([self.dV, self.dw])

def update(self, t, dt):
V, w = self.integral(self.V, self.w, t, self.input, dt=dt)
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt)
spike = V >= self.V_th
self.V.value = bm.where(spike, self.V_reset, V)
self.w.value = bm.where(spike, w + self.b, w)
self.spike.value = spike

def clear_input(self):
self.input[:] = 0.
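
`derivative` joins `dV` and `dw` through `JointEq`, and on a spike V is reset while the adaptation current w jumps by `b`. A forward-Euler sketch of one full step (the `dw` expression below is the standard AdEx adaptation equation, assumed here because the hunk cuts it off; defaults mirror the constructor):

import numpy as np

def adex_step(V, w, I, dt=0.1, V_rest=-65., V_T=-59.9, delta_T=3.48,
              V_th=-30., V_reset=-68., a=1., b=1., R=1., tau=10., tau_w=30.):
    exp_v = delta_T * np.exp((V - V_T) / delta_T)
    dV = (-(V - V_rest) + exp_v - R * w + R * I) / tau
    dw = (a * (V - V_rest) - w) / tau_w   # assumed standard AdEx adaptation dynamics
    V, w = V + dt * dV, w + dt * dw
    spike = V >= V_th
    V = np.where(spike, V_reset, V)
    w = np.where(spike, w + b, w)
    return V, w, spike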


@@ -534,65 +725,91 @@ class QuaIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = -65.,
V_reset: Union[float, Tensor, Initializer, Callable] = -68.,
V_th: Union[float, Tensor, Initializer, Callable] = -30.,
V_c: Union[float, Tensor, Initializer, Callable] = -50.0,
c: Union[float, Tensor, Initializer, Callable] = .07,
R: Union[float, Tensor, Initializer, Callable] = 1.,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
tau_ref: Union[float, Tensor, Initializer, Callable] = 0.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
V_rest: Union[float, Array, Initializer, Callable] = -65.,
V_reset: Union[float, Array, Initializer, Callable] = -68.,
V_th: Union[float, Array, Initializer, Callable] = -30.,
V_c: Union[float, Array, Initializer, Callable] = -50.0,
c: Union[float, Array, Initializer, Callable] = .07,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_ref: Union[float, Array, Initializer, Callable] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
keep_size: bool = False,
mode: Mode = normal,
method: str = 'exp_auto',
name: str = None
):
# initialization
super(QuaIF, self).__init__(size=size, name=name)
super(QuaIF, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.V_reset = init_param(V_reset, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_c = init_param(V_c, self.num, allow_none=False)
self.c = init_param(c, self.num, allow_none=False)
self.R = init_param(R, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.tau_ref = init_param(tau_ref, self.num, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_c = parameter(V_c, self.varshape, allow_none=False)
self.c = parameter(c, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
self.noise = init_noise(noise, self.varshape, num_vars=1)

# initializers
check_initializer(V_initializer, '_V_initializer', allow_none=False)
self._V_initializer = V_initializer

# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
if self.tau_ref is not None:
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
self.t_last_spike[:] = -1e7
self.refractory[:] = False
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
if self.tau_ref is not None:
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def derivative(self, V, t, I_ext):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau
return dVdt

def update(self, t, dt, **kwargs):
refractory = (t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, t, self.input, dt=dt)
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
V = self.integral(self.V.value, t, self.input.value, dt)
if self.tau_ref is not None:
refractory = (t - self.t_last_spike) <= self.tau_ref
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
t_last_spike = bm.where(spike, t, self.t_last_spike)
V = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
else:
spike = self.V_th <= V
t_last_spike = bm.where(spike, t, self.t_last_spike)
V = bm.where(spike, self.V_reset, V)
self.V.value = V
self.spike.value = spike
self.t_last_spike.value = t_last_spike

def clear_input(self):
self.input[:] = 0.
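
For completeness, the quadratic right-hand side used by `derivative` above as a standalone NumPy sketch (defaults copied from the constructor); compared with ExpIF, the spike upstroke comes from the quadratic term rather than an exponential:

import numpy as np

def quaif_dvdt(V, I_ext, V_rest=-65., V_c=-50., c=0.07, R=1., tau=10.):
    return (c * (V - V_rest) * (V - V_c) + R * I_ext) / tau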


@@ -676,32 +893,40 @@ class AdQuaIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = -65.,
V_reset: Union[float, Tensor, Initializer, Callable] = -68.,
V_th: Union[float, Tensor, Initializer, Callable] = -30.,
V_c: Union[float, Tensor, Initializer, Callable] = -50.0,
a: Union[float, Tensor, Initializer, Callable] = 1.,
b: Union[float, Tensor, Initializer, Callable] = .1,
c: Union[float, Tensor, Initializer, Callable] = .07,
tau: Union[float, Tensor, Initializer, Callable] = 10.,
tau_w: Union[float, Tensor, Initializer, Callable] = 10.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
V_rest: Union[float, Array, Initializer, Callable] = -65.,
V_reset: Union[float, Array, Initializer, Callable] = -68.,
V_th: Union[float, Array, Initializer, Callable] = -30.,
V_c: Union[float, Array, Initializer, Callable] = -50.0,
a: Union[float, Array, Initializer, Callable] = 1.,
b: Union[float, Array, Initializer, Callable] = .1,
c: Union[float, Array, Initializer, Callable] = .07,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_w: Union[float, Array, Initializer, Callable] = 10.,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
keep_size: bool = False,
mode: Mode = normal,
name: str = None
):
super(AdQuaIF, self).__init__(size=size, name=name)
super(AdQuaIF, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode, )
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.V_reset = init_param(V_reset, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_c = init_param(V_c, self.num, allow_none=False)
self.c = init_param(c, self.num, allow_none=False)
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.tau_w = init_param(tau_w, self.num, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_c = parameter(V_c, self.varshape, allow_none=False)
self.c = parameter(c, self.varshape, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.tau_w = parameter(tau_w, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=2)

# initializers
check_initializer(V_initializer, 'V_initializer', allow_none=False)
@@ -710,21 +935,26 @@ class AdQuaIF(NeuGroup):
self._w_initializer = w_initializer

# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = variable(V_initializer, mode, self.varshape)
self.w = variable(w_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.w.value = init_param(self._w_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
self.refractory[:] = False
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.w.value = variable(self._w_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def dV(self, V, t, w, I_ext):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau
@@ -738,12 +968,16 @@ class AdQuaIF(NeuGroup):
def derivative(self):
return JointEq([self.dV, self.dw])

def update(self, t, dt):
V, w = self.integral(self.V, self.w, t, self.input, dt=dt)
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt)
spike = self.V_th <= V
self.V.value = bm.where(spike, self.V_reset, V)
self.w.value = bm.where(spike, w + self.b, w)
self.spike.value = spike

def clear_input(self):
self.input[:] = 0.


@@ -832,45 +1066,57 @@ class GIF(NeuGroup):
def __init__(
self,
size: Shape,
V_rest: Union[float, Tensor, Initializer, Callable] = -70.,
V_reset: Union[float, Tensor, Initializer, Callable] = -70.,
V_th_inf: Union[float, Tensor, Initializer, Callable] = -50.,
V_th_reset: Union[float, Tensor, Initializer, Callable] = -60.,
R: Union[float, Tensor, Initializer, Callable] = 20.,
tau: Union[float, Tensor, Initializer, Callable] = 20.,
a: Union[float, Tensor, Initializer, Callable] = 0.,
b: Union[float, Tensor, Initializer, Callable] = 0.01,
k1: Union[float, Tensor, Initializer, Callable] = 0.2,
k2: Union[float, Tensor, Initializer, Callable] = 0.02,
R1: Union[float, Tensor, Initializer, Callable] = 0.,
R2: Union[float, Tensor, Initializer, Callable] = 1.,
A1: Union[float, Tensor, Initializer, Callable] = 0.,
A2: Union[float, Tensor, Initializer, Callable] = 0.,
V_initializer: Union[Initializer, Callable, Tensor] = OneInit(-70.),
I1_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
I2_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
Vth_initializer: Union[Initializer, Callable, Tensor] = OneInit(-50.),
V_rest: Union[float, Array, Initializer, Callable] = -70.,
V_reset: Union[float, Array, Initializer, Callable] = -70.,
V_th_inf: Union[float, Array, Initializer, Callable] = -50.,
V_th_reset: Union[float, Array, Initializer, Callable] = -60.,
R: Union[float, Array, Initializer, Callable] = 20.,
tau: Union[float, Array, Initializer, Callable] = 20.,
a: Union[float, Array, Initializer, Callable] = 0.,
b: Union[float, Array, Initializer, Callable] = 0.01,
k1: Union[float, Array, Initializer, Callable] = 0.2,
k2: Union[float, Array, Initializer, Callable] = 0.02,
R1: Union[float, Array, Initializer, Callable] = 0.,
R2: Union[float, Array, Initializer, Callable] = 1.,
A1: Union[float, Array, Initializer, Callable] = 0.,
A2: Union[float, Array, Initializer, Callable] = 0.,
V_initializer: Union[Initializer, Callable, Array] = OneInit(-70.),
I1_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
I2_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
Vth_initializer: Union[Initializer, Callable, Array] = OneInit(-50.),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
keep_size: bool = False,
name: str = None,

# parameter for training
mode: Mode = normal,
spike_fun: Callable = bm.spike_with_sigmoid_grad,
):
# initialization
super(GIF, self).__init__(size=size, name=name)
super(GIF, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# params
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.V_reset = init_param(V_reset, self.num, allow_none=False)
self.V_th_inf = init_param(V_th_inf, self.num, allow_none=False)
self.V_th_reset = init_param(V_th_reset, self.num, allow_none=False)
self.R = init_param(R, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.k1 = init_param(k1, self.num, allow_none=False)
self.k2 = init_param(k2, self.num, allow_none=False)
self.R1 = init_param(R1, self.num, allow_none=False)
self.R2 = init_param(R2, self.num, allow_none=False)
self.A1 = init_param(A1, self.num, allow_none=False)
self.A2 = init_param(A2, self.num, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th_inf = parameter(V_th_inf, self.varshape, allow_none=False)
self.V_th_reset = parameter(V_th_reset, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.k1 = parameter(k1, self.varshape, allow_none=False)
self.k2 = parameter(k2, self.varshape, allow_none=False)
self.R1 = parameter(R1, self.varshape, allow_none=False)
self.R2 = parameter(R2, self.varshape, allow_none=False)
self.A1 = parameter(A1, self.varshape, allow_none=False)
self.A2 = parameter(A2, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=4)
self.spike_fun = check_callable(spike_fun, 'spike_fun')

# initializers
check_initializer(V_initializer, 'V_initializer')
@@ -883,23 +1129,28 @@ class GIF(NeuGroup):
self._Vth_initializer = Vth_initializer

# variables
self.I1 = bm.Variable(init_param(I1_initializer, (self.num,)))
self.I2 = bm.Variable(init_param(I2_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.V_th = bm.Variable(init_param(Vth_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.I1 = variable(I1_initializer, mode, self.varshape)
self.I2 = variable(I2_initializer, mode, self.varshape)
self.V_th = variable(Vth_initializer, mode, self.varshape)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)

# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.I1.value = init_param(self._I1_initializer, (self.num,))
self.I2.value = init_param(self._I2_initializer, (self.num,))
self.V_th.value = init_param(self._Vth_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
def reset_state(self, batch_size=None):
self.I1.value = variable(self._I1_initializer, batch_size, self.varshape)
self.I2.value = variable(self._I2_initializer, batch_size, self.varshape)
self.V_th.value = variable(self._Vth_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)

def dI1(self, I1, t):
return - self.k1 * I1
@@ -917,19 +1168,199 @@ class GIF(NeuGroup):
def derivative(self):
return JointEq([self.dI1, self.dI2, self.dVth, self.dV])

def update(self, t, dt):
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt

# integral
if x is not None: self.input += x
I1, I2, V_th, V = self.integral(self.I1, self.I2, self.V_th, self.V, t, self.input, dt=dt)
spike = self.V_th <= V
V = bm.where(spike, self.V_reset, V)
I1 = bm.where(spike, self.R1 * I1 + self.A1, I1)
I2 = bm.where(spike, self.R2 * I2 + self.A2, I2)
reset_th = bm.logical_and(V_th < self.V_th_reset, spike)
V_th = bm.where(reset_th, self.V_th_reset, V_th)

# spike and resets
if isinstance(self.mode, TrainingMode):
spike = self.spike_fun(V - self.V_th)
V += (self.V_reset - V) * spike
I1 += spike * (self.R1 * I1 + self.A1 - I1)
I2 += spike * (self.R2 * I2 + self.A2 - I2)
reset_th = self.spike_fun(self.V_th_reset - V_th) * spike
V_th += reset_th * (self.V_th_reset - V_th)
else:
spike = self.V_th <= V
V = bm.where(spike, self.V_reset, V)
I1 = bm.where(spike, self.R1 * I1 + self.A1, I1)
I2 = bm.where(spike, self.R2 * I2 + self.A2, I2)
reset_th = bm.logical_and(V_th < self.V_th_reset, spike)
V_th = bm.where(reset_th, self.V_th_reset, V_th)
self.spike.value = spike
self.I1.value = I1
self.I2.value = I2
self.V_th.value = V_th
self.V.value = V

def clear_input(self):
self.input[:] = 0.


class ALIFBellec2020(NeuGroup):
r"""Leaky Integrate-and-Fire model with SFA [1]_.

This model is similar to the GLIF2 model in the Technical White Paper
on generalized LIF (GLIF) models from the Allen Institute [2]_.

Formally, this model is given by:

.. math::

\tau \dot{V} = -(V - V_{\mathrm{rest}}) + R I \\
\tau_a \dot{a} = -a

Once a spike is induced by :math:`V(t) > V_{\mathrm{th}} + \beta a`, then

.. math::

V \gets V - V_{\mathrm{th}} \\
a \gets a + 1


References
----------
.. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for
recurrent networks of spiking neurons."
Nature communications 11.1 (2020): 1-15.
.. [2] Allen Institute: Cell Types Database. © 2018 Allen Institute for
Brain Science. Allen Cell Types Database, cell feature search.
Available from: celltypes.brain-map.org/data (2018).
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,

# model parameters
V_rest: Union[float, Array, Initializer, Callable] = -70.,
V_th: Union[float, Array, Initializer, Callable] = -60.,
R: Union[float, Array, Initializer, Callable] = 1.,
beta: Union[float, Array, Initializer, Callable] = 1.6,
tau: Union[float, Array, Initializer, Callable] = 20.,
tau_a: Union[float, Array, Initializer, Callable] = 2000.,
tau_ref: Union[float, Array, Initializer, Callable] = None,
noise: Union[float, Array, Initializer, Callable] = None,

# initializers
V_initializer: Union[Initializer, Callable, Array] = OneInit(-70.),
a_initializer: Union[Initializer, Callable, Array] = OneInit(-50.),

# parameter for training
spike_fun: Callable = bm.spike_with_linear_grad,

# other parameters
method: str = 'exp_auto',
name: str = None,
mode: Mode = normal,
eprop: bool = False
):
super(ALIFBellec2020, self).__init__(name=name,
size=size,
keep_size=keep_size,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.beta = parameter(beta, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.tau_a = parameter(tau_a, self.varshape, allow_none=False)
self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
self.noise = init_noise(noise, self.varshape, num_vars=2)
self.spike_fun = check_callable(spike_fun, 'spike_fun')
self.eprop = eprop

# initializers
check_initializer(V_initializer, 'V_initializer')
check_initializer(a_initializer, 'a_initializer')
self._V_initializer = V_initializer
self._a_initializer = a_initializer

# variables
self.a = variable(a_initializer, mode, self.varshape)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
if self.tau_ref is not None:
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def da(self, a, t):
return -a / self.tau_a

def dV(self, V, t, I_ext):
return (- (V - self.V_rest) + self.R * I_ext) / self.tau

@property
def derivative(self):
return JointEq([self.dV, self.da])

def reset_state(self, batch_size=None):
self.a.value = variable(self._a_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
if self.tau_ref is not None:
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt

# integral
if x is not None: self.input += x
V, a = self.integral(self.V, self.a, t, self.input, dt)

if self.tau_ref is not None:
# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
if isinstance(self.mode, TrainingMode):
refractory = stop_gradient(refractory)
V = bm.where(refractory, self.V, V)
# spike and reset
if isinstance(self.mode, TrainingMode):
spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th)
V -= self.V_th * (stop_gradient(spike) if self.eprop else spike)
# will be used in other places (e.g., the Delta synapse), so stop its gradient
spike_ = spike > 0.
refractory = stop_gradient(bm.logical_or(refractory, spike_).value)
t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike).value)
else:
spike = V >= (self.V_th + self.beta * self.a)
refractory = bm.logical_or(refractory, spike)
t_last_spike = bm.where(spike, t, self.t_last_spike)
V -= self.V_th * spike
self.refractory.value = refractory
self.t_last_spike.value = t_last_spike

else:
# spike and reset
if isinstance(self.mode, TrainingMode):
spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th)
V -= self.V_th * (stop_gradient(spike) if self.eprop else spike)
else:
spike = V >= (self.V_th + self.beta * self.a)
V -= self.V_th * spike
self.spike.value = spike
self.V.value = V
self.a.value = a + spike

def clear_input(self):
self.input[:] = 0.
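
The difference from a plain LIF is the adaptive threshold: a spike fires when V exceeds V_th + beta * a, V is lowered by V_th (a soft reset), and a is incremented before decaying with tau_a. A minimal forward-Euler sketch of one step (illustrative only; no refractory handling, defaults copied from the constructor):

import numpy as np

def alif_step(V, a, I, dt=0.1, V_rest=-70., V_th=-60., R=1., beta=1.6, tau=20., tau_a=2000.):
    V = V + dt * (-(V - V_rest) + R * I) / tau
    a = a + dt * (-a / tau_a)
    spike = V >= (V_th + beta * a)
    V = V - V_th * spike   # soft reset: subtract the threshold, as in the update above
    a = a + spike          # spike-triggered increment of the adaptation variable
    return V, a, spike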


@@ -1004,27 +1435,37 @@ class Izhikevich(NeuGroup):
def __init__(
self,
size: Shape,
a: Union[float, Tensor, Initializer, Callable] = 0.02,
b: Union[float, Tensor, Initializer, Callable] = 0.20,
c: Union[float, Tensor, Initializer, Callable] = -65.,
d: Union[float, Tensor, Initializer, Callable] = 8.,
V_th: Union[float, Tensor, Initializer, Callable] = 30.,
tau_ref: Union[float, Tensor, Initializer, Callable] = 0.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
u_initializer: Union[Initializer, Callable, Tensor] = OneInit(),
a: Union[float, Array, Initializer, Callable] = 0.02,
b: Union[float, Array, Initializer, Callable] = 0.20,
c: Union[float, Array, Initializer, Callable] = -65.,
d: Union[float, Array, Initializer, Callable] = 8.,
V_th: Union[float, Array, Initializer, Callable] = 30.,
tau_ref: Union[float, Array, Initializer, Callable] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
u_initializer: Union[Initializer, Callable, Array] = OneInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
mode: Mode = normal,
spike_fun: Callable = bm.spike_with_sigmoid_grad,
keep_size: bool = False,
name: str = None
):
# initialization
super(Izhikevich, self).__init__(size=size, name=name)
super(Izhikevich, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# params
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.c = init_param(c, self.num, allow_none=False)
self.d = init_param(d, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.tau_ref = init_param(tau_ref, self.num, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.c = parameter(c, self.varshape, allow_none=False)
self.d = parameter(d, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
self.noise = init_noise(noise, self.varshape, num_vars=2)
self.spike_fun = check_callable(spike_fun, 'spike_fun')

# initializers
check_initializer(V_initializer, 'V_initializer', allow_none=False)
@@ -1033,23 +1474,30 @@ class Izhikevich(NeuGroup):
self._u_initializer = u_initializer

# variables
self.u = bm.Variable(init_param(u_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.u = variable(u_initializer, mode, self.varshape)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
if self.tau_ref is not None:
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# functions
self.integral = odeint(method=method, f=JointEq([self.dV, self.du]))

def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.u.value = init_param(self._u_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
self.refractory[:] = False
self.t_last_spike[:] = -1e7
if self.noise is None:
self.integral = odeint(method=method, f=JointEq([self.dV, self.du]))
else:
self.integral = sdeint(method=method, f=JointEq([self.dV, self.du]), g=self.noise)

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.u.value = variable(self._u_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
if self.tau_ref is not None:
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def dV(self, V, t, u, I_ext):
dVdt = 0.04 * V * V + 5 * V + 140 - u + I_ext
@@ -1059,16 +1507,55 @@ class Izhikevich(NeuGroup):
dudt = self.a * (self.b * V - u)
return dudt

def update(self, t, dt):
V, u = self.integral(self.V, self.u, t, self.input, dt=dt)
refractory = (t - self.t_last_spike) <= self.tau_ref
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)
self.V.value = bm.where(spike, self.c, V)
self.u.value = bm.where(spike, u + self.d, u)
self.refractory.value = bm.logical_or(refractory, spike)
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt

# integrate membrane potential
if x is not None: self.input += x
V, u = self.integral(self.V, self.u, t, self.input, dt)

if self.tau_ref is not None:
refractory = (t - self.t_last_spike) <= self.tau_ref
if isinstance(self.mode, TrainingMode):
refractory = stop_gradient(refractory)
V = bm.where(refractory, self.V, V)

# spike, refractory, and reset membrane potential
if isinstance(self.mode, TrainingMode):
spike = self.spike_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike)
V += spike_no_grad * (self.c - self.V_th)
u += spike_no_grad * self.d
spike_ = spike_no_grad > 0.
refractory = stop_gradient(bm.logical_or(refractory, spike_).value)
t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike).value)
else:
spike = self.V_th <= V
V = bm.where(spike, self.c, V)
u = bm.where(spike, u + self.d, u)
refractory = bm.logical_or(refractory, spike)
t_last_spike = bm.where(spike, t, self.t_last_spike)
self.refractory.value = refractory
self.t_last_spike.value = t_last_spike

else:
# spike, refractory, and reset membrane potential
if isinstance(self.mode, TrainingMode):
spike = self.spike_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike)
V += spike_no_grad * (self.c - self.V_th)
u += spike_no_grad * self.d
else:
spike = self.V_th <= V
V = bm.where(spike, self.c, V)
u = bm.where(spike, u + self.d, u)

# finally
self.V.value = V
self.u.value = u
self.spike.value = spike

def clear_input(self):
self.input[:] = 0.
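
The Izhikevich pair above (dV/dt = 0.04 V^2 + 5 V + 140 - u + I, du/dt = a (b V - u)) with the reset V <- c, u <- u + d reproduces a wide range of firing patterns from just four parameters. A tiny NumPy simulation sketch using the constructor defaults (illustrative only):

import numpy as np

def izhikevich_run(I=10., T=200., dt=0.1, a=0.02, b=0.2, c=-65., d=8., V_th=30.):
    V, u, spike_times = c, b * c, []
    for step in range(int(T / dt)):
        V += dt * (0.04 * V * V + 5 * V + 140 - u + I)
        u += dt * (a * (b * V - u))
        if V >= V_th:
            V, u = c, u + d
            spike_times.append(step * dt)
    return spike_times

print(len(izhikevich_run()))  # number of spikes in 200 ms of constant drive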


@@ -1173,32 +1660,44 @@ class HindmarshRose(NeuGroup):
def __init__(
self,
size: Shape,
a: Union[float, Tensor, Initializer, Callable] = 1.,
b: Union[float, Tensor, Initializer, Callable] = 3.,
c: Union[float, Tensor, Initializer, Callable] = 1.,
d: Union[float, Tensor, Initializer, Callable] = 5.,
r: Union[float, Tensor, Initializer, Callable] = 0.01,
s: Union[float, Tensor, Initializer, Callable] = 4.,
V_rest: Union[float, Tensor, Initializer, Callable] = -1.6,
V_th: Union[float, Tensor, Initializer, Callable] = 1.0,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
y_initializer: Union[Initializer, Callable, Tensor] = OneInit(-10.),
z_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
a: Union[float, Array, Initializer, Callable] = 1.,
b: Union[float, Array, Initializer, Callable] = 3.,
c: Union[float, Array, Initializer, Callable] = 1.,
d: Union[float, Array, Initializer, Callable] = 5.,
r: Union[float, Array, Initializer, Callable] = 0.01,
s: Union[float, Array, Initializer, Callable] = 4.,
V_rest: Union[float, Array, Initializer, Callable] = -1.6,
V_th: Union[float, Array, Initializer, Callable] = 1.0,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
y_initializer: Union[Initializer, Callable, Array] = OneInit(-10.),
z_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
keep_size: bool = False,
name: str = None,

# parameters for training
mode: Mode = normal,
spike_fun: Callable = bm.spike2_with_sigmoid_grad,
):
# initialization
super(HindmarshRose, self).__init__(size=size, name=name)
super(HindmarshRose, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.c = init_param(c, self.num, allow_none=False)
self.d = init_param(d, self.num, allow_none=False)
self.r = init_param(r, self.num, allow_none=False)
self.s = init_param(s, self.num, allow_none=False)
self.V_th = init_param(V_th, self.num, allow_none=False)
self.V_rest = init_param(V_rest, self.num, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.c = parameter(c, self.varshape, allow_none=False)
self.d = parameter(d, self.varshape, allow_none=False)
self.r = parameter(r, self.varshape, allow_none=False)
self.s = parameter(s, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=3)
self.spike_fun = check_callable(spike_fun, 'spike_fun')

# variables
check_initializer(V_initializer, 'V_initializer', allow_none=False)
@@ -1209,21 +1708,26 @@ class HindmarshRose(NeuGroup):
self._z_initializer = z_initializer

# variables
self.z = bm.Variable(init_param(V_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.V = bm.Variable(init_param(z_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = variable(self._V_initializer, mode, self.varshape)
self.y = variable(self._y_initializer, mode, self.varshape)
self.z = variable(self._z_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)

# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.y.value = init_param(self._y_initializer, (self.num,))
self.z.value = init_param(self._z_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.y.value = variable(self._y_initializer, batch_size, self.varshape)
self.z.value = variable(self._z_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)

def dV(self, V, t, y, z, I_ext):
return y - self.a * V * V * V + self.b * V * V - z + I_ext
@@ -1238,12 +1742,19 @@ class HindmarshRose(NeuGroup):
def derivative(self):
return JointEq([self.dV, self.dy, self.dz])

def update(self, t, dt):
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
V, y, z = self.integral(self.V, self.y, self.z, t, self.input, dt=dt)
self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
if isinstance(self.mode, TrainingMode):
self.spike.value = self.spike_fun(V - self.V_th, self.V - self.V_th)
else:
self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
self.V.value = V
self.y.value = y
self.z.value = z

def clear_input(self):
self.input[:] = 0.
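
In training mode the hard spike test is replaced by `spike_fun(V - V_th, self.V - V_th)` (default `bm.spike2_with_sigmoid_grad`), i.e. a two-argument crossing detector that stays differentiable in the backward pass. The snippet below is an illustrative re-implementation of such a surrogate in plain JAX, not the actual brainpy.math code, which may differ in the surrogate shape and scaling.

# Illustrative two-argument surrogate spike function:
# forward  -> hard crossing detector (new value above threshold, old value below),
# backward -> sigmoid-derivative surrogate w.r.t. the new membrane potential.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def spike2_sigmoid(v_new, v_old):
  # inputs are already threshold-centered, e.g. V - V_th and V_old - V_th
  return jnp.logical_and(v_new >= 0., v_old < 0.).astype(jnp.float32)

def _spike2_fwd(v_new, v_old):
  return spike2_sigmoid(v_new, v_old), (v_new, v_old)

def _spike2_bwd(res, g):
  v_new, v_old = res
  s = jax.nn.sigmoid(v_new)
  # smooth gradient for the new potential, zero gradient for the old one
  return g * s * (1. - s), jnp.zeros_like(v_old)

spike2_sigmoid.defvjp(_spike2_fwd, _spike2_bwd)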


@@ -1333,23 +1844,35 @@ class FHN(NeuGroup):
def __init__(
self,
size: Shape,
a: Union[float, Tensor, Initializer, Callable] = 0.7,
b: Union[float, Tensor, Initializer, Callable] = 0.8,
tau: Union[float, Tensor, Initializer, Callable] = 12.5,
Vth: Union[float, Tensor, Initializer, Callable] = 1.8,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
a: Union[float, Array, Initializer, Callable] = 0.7,
b: Union[float, Array, Initializer, Callable] = 0.8,
tau: Union[float, Array, Initializer, Callable] = 12.5,
Vth: Union[float, Array, Initializer, Callable] = 1.8,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
keep_size: bool = False,
name: str = None,

# parameters for training
mode: Mode = normal,
spike_fun: Callable = bm.spike2_with_sigmoid_grad,
):
# initialization
super(FHN, self).__init__(size=size, name=name)
super(FHN, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.Vth = init_param(Vth, self.num, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.Vth = parameter(Vth, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=2)
self.spike_fun = check_callable(spike_fun, 'spike_fun')

# initializers
check_initializer(V_initializer, 'V_initializer')
@@ -1358,19 +1881,24 @@ class FHN(NeuGroup):
self._w_initializer = w_initializer

# variables
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = variable(self._V_initializer, mode, self.varshape)
self.w = variable(self._w_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)

# integral
self.integral = odeint(method=method, f=self.derivative)
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
else:
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset(self):
self.V.value = init_param(self._V_initializer, (self.num,))
self.w.value = init_param(self._w_initializer, (self.num,))
self.input[:] = 0
self.spike[:] = False
def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.w.value = variable(self._w_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)

def dV(self, V, t, w, I_ext):
return V - V * V * V / 3 - w + I_ext
@@ -1382,9 +1910,16 @@ class FHN(NeuGroup):
def derivative(self):
return JointEq([self.dV, self.dw])

def update(self, t, dt):
V, w = self.integral(self.V, self.w, t, self.input, dt=dt)
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth)
def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt=dt)
if isinstance(self.mode, TrainingMode):
self.spike.value = self.spike_fun(V - self.Vth, self.V - self.Vth)
else:
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth)
self.V.value = V
self.w.value = w

def clear_input(self):
self.input[:] = 0.

+17  -0   brainpy/dyn/neurons/tests/test_reduced_models.py

@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-


import brainpy as bp
from absl.testing import parameterized
from brainpy.dyn.neurons import reduced_models


class TestNoise(parameterized.TestCase):
@parameterized.named_parameters(
{'testcase_name': f'noise_of_{name}', 'neuron': name}
for name in reduced_models.__all__
)
def test_noise(self, neuron):
model = getattr(reduced_models, neuron)(size=1, noise=0.1)
runner = bp.dyn.DSRunner(model, progress_bar=False)
runner.run(10.)
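
The parameterized test above instantiates every model listed in `reduced_models.__all__` with a non-zero `noise`, which (per the constructors above) switches the integrator from `odeint` to `sdeint`. The same pattern for a single model, following the test's own imports:

>>> import brainpy as bp
>>> from brainpy.dyn.neurons import reduced_models
>>> model = reduced_models.FHN(size=1, noise=0.1)   # noise=None would keep the deterministic odeint
>>> runner = bp.dyn.DSRunner(model, progress_bar=False)
>>> runner.run(10.)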

+0  -3   brainpy/dyn/rates/__init__.py

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-

from .noises import *
from .populations import *
from .couplings import *


+412  -312   brainpy/dyn/rates/populations.py

@@ -2,22 +2,19 @@

from typing import Union, Callable

import numpy as np
from jax.experimental.host_callback import id_tap

import brainpy.math as bm
from brainpy import check
from brainpy import check, math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import Initializer, Uniform, init_param, ZeroInit
from brainpy.integrators.dde import ddeint
from brainpy.dyn.neurons.noise_groups import OUProcess
from brainpy.initialize import Initializer, Uniform, parameter, variable, ZeroInit
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.modes import Mode, normal
from brainpy.tools.checking import check_float, check_initializer
from brainpy.types import Shape, Tensor
from .noises import OUProcess
from brainpy.tools.errors import check_error_in_jit
from brainpy.types import Shape, Array

__all__ = [
'Population',
'RateModel',
'FHN',
'FeedbackFHN',
'QIF',
@@ -27,12 +24,11 @@ __all__ = [
]


class Population(NeuGroup):
def update(self, t, dt):
raise NotImplementedError
class RateModel(NeuGroup):
pass


class FHN(NeuGroup):
class FHN(RateModel):
r"""FitzHugh-Nagumo system used in [1]_.

.. math::
@@ -70,49 +66,55 @@ class FHN(NeuGroup):
def __init__(
self,
size: Shape,
keep_size: bool = False,

# fhn parameters
alpha: Union[float, Tensor, Initializer, Callable] = 3.0,
beta: Union[float, Tensor, Initializer, Callable] = 4.0,
gamma: Union[float, Tensor, Initializer, Callable] = -1.5,
delta: Union[float, Tensor, Initializer, Callable] = 0.0,
epsilon: Union[float, Tensor, Initializer, Callable] = 0.5,
tau: Union[float, Tensor, Initializer, Callable] = 20.0,
alpha: Union[float, Array, Initializer, Callable] = 3.0,
beta: Union[float, Array, Initializer, Callable] = 4.0,
gamma: Union[float, Array, Initializer, Callable] = -1.5,
delta: Union[float, Array, Initializer, Callable] = 0.0,
epsilon: Union[float, Array, Initializer, Callable] = 0.5,
tau: Union[float, Array, Initializer, Callable] = 20.0,

# noise parameters
x_ou_mean: Union[float, Tensor, Initializer, Callable] = 0.0,
x_ou_sigma: Union[float, Tensor, Initializer, Callable] = 0.0,
x_ou_tau: Union[float, Tensor, Initializer, Callable] = 5.0,
y_ou_mean: Union[float, Tensor, Initializer, Callable] = 0.0,
y_ou_sigma: Union[float, Tensor, Initializer, Callable] = 0.0,
y_ou_tau: Union[float, Tensor, Initializer, Callable] = 5.0,
x_ou_mean: Union[float, Array, Initializer, Callable] = 0.0,
x_ou_sigma: Union[float, Array, Initializer, Callable] = 0.0,
x_ou_tau: Union[float, Array, Initializer, Callable] = 5.0,
y_ou_mean: Union[float, Array, Initializer, Callable] = 0.0,
y_ou_sigma: Union[float, Array, Initializer, Callable] = 0.0,
y_ou_tau: Union[float, Array, Initializer, Callable] = 5.0,

# other parameters
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
x_initializer: Union[Initializer, Callable, Array] = Uniform(0, 0.05),
y_initializer: Union[Initializer, Callable, Array] = Uniform(0, 0.05),
method: str = 'exp_auto',
sde_method: str = None,
name: str = None,

# parameter for training
mode: Mode = normal,
):
super(FHN, self).__init__(size=size, name=name)
super(FHN, self).__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)

# model parameters
self.alpha = init_param(alpha, self.num, allow_none=False)
self.beta = init_param(beta, self.num, allow_none=False)
self.gamma = init_param(gamma, self.num, allow_none=False)
self.delta = init_param(delta, self.num, allow_none=False)
self.epsilon = init_param(epsilon, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.alpha = parameter(alpha, self.varshape, allow_none=False)
self.beta = parameter(beta, self.varshape, allow_none=False)
self.gamma = parameter(gamma, self.varshape, allow_none=False)
self.delta = parameter(delta, self.varshape, allow_none=False)
self.epsilon = parameter(epsilon, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)

# noise parameters
self.x_ou_mean = init_param(x_ou_mean, self.num, allow_none=False) # mV/ms, OU process
self.y_ou_mean = init_param(y_ou_mean, self.num, allow_none=False) # mV/ms, OU process
self.x_ou_sigma = init_param(x_ou_sigma, self.num, allow_none=False) # mV/ms/sqrt(ms), noise intensity
self.y_ou_sigma = init_param(y_ou_sigma, self.num, allow_none=False) # mV/ms/sqrt(ms), noise intensity
self.x_ou_tau = init_param(x_ou_tau, self.num,
allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process
self.y_ou_tau = init_param(y_ou_tau, self.num,
allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process
self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) # mV/ms, OU process
self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) # mV/ms, OU process
self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) # mV/ms/sqrt(ms), noise intensity
self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) # mV/ms/sqrt(ms), noise intensity
self.x_ou_tau = parameter(x_ou_tau, self.varshape,
allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process
self.y_ou_tau = parameter(y_ou_tau, self.varshape,
allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process

# initializers
check_initializer(x_initializer, 'x_initializer')
@@ -121,32 +123,38 @@ class FHN(NeuGroup):
self._y_initializer = y_initializer

# variables
self.x = bm.Variable(init_param(x_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.x = variable(x_initializer, mode, self.varshape)
self.y = variable(y_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.input_y = variable(bm.zeros, mode, self.varshape)

# noise variables
self.x_ou = self.y_ou = None
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.):
self.x_ou = OUProcess(self.num,
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau,
method=sde_method)
self.x_ou = OUProcess(self.varshape,
self.x_ou_mean,
self.x_ou_sigma,
self.x_ou_tau,
method=method)
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.):
self.y_ou = OUProcess(self.num,
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau,
method=sde_method)
self.y_ou = OUProcess(self.varshape,
self.y_ou_mean,
self.y_ou_sigma,
self.y_ou_tau,
method=method)

# integral functions
self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method)

def reset(self):
self.x.value = init_param(self._x_initializer, (self.num,))
self.y.value = init_param(self._y_initializer, (self.num,))
self.input[:] = 0
def reset_state(self, batch_size=None):
self.x.value = variable(self._x_initializer, batch_size, self.varshape)
self.y.value = variable(self._y_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.input_y.value = variable(bm.zeros, batch_size, self.varshape)
if self.x_ou is not None:
self.x_ou.reset()
self.x_ou.reset_state(batch_size)
if self.y_ou is not None:
self.y_ou.reset()
self.y_ou.reset_state(batch_size)

def dx(self, x, t, y, x_ext):
return - self.alpha * x ** 3 + self.beta * x ** 2 + self.gamma * x - y + x_ext
@@ -154,21 +162,29 @@ class FHN(NeuGroup):
def dy(self, y, t, x, y_ext=0.):
return (x - self.delta - self.epsilon * y) / self.tau + y_ext

def update(self, t, dt):
def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']

# input
if x is not None: self.input += x
if self.x_ou is not None:
self.input += self.x_ou.x
self.x_ou.update(t, dt)
y_ext = 0.
self.x_ou.update(tdi)
if self.y_ou is not None:
y_ext = self.y_ou.x
self.y_ou.update(t, dt)
x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=y_ext, dt=dt)
self.input_y += self.y_ou.x
self.y_ou.update(tdi)

# integral
x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt)
self.x.value = x
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
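
When `x_ou_sigma` or `y_ou_sigma` is non-zero, the population drives itself through the internal `OUProcess` nodes, whose outputs are now accumulated into the separate `input` and `input_y` buffers and cleared in `clear_input()`. A usage sketch in the style of the other rate-model examples, assuming the class is exported as `bp.dyn.rates.FHN` like `FeedbackFHN` below:

>>> import brainpy as bp
>>> fhn = bp.dyn.rates.FHN(1, x_ou_sigma=0.01, y_ou_sigma=0.01)
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['x', 'y'])
>>> runner.run(100.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x')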


class FeedbackFHN(NeuGroup):
class FeedbackFHN(RateModel):
r"""FitzHugh-Nagumo model with recurrent neural feedback.

The equation of the feedback FitzHugh-Nagumo model [4]_ is given by
@@ -184,7 +200,7 @@ class FeedbackFHN(NeuGroup):
**Model Examples**

>>> import brainpy as bp
>>> fhn = bp.dyn.FeedbackFHN(1, delay=10.)
>>> fhn = bp.dyn.rates.FeedbackFHN(1, delay=10.)
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['x', 'y'])
>>> runner.run(100.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y')
@@ -221,8 +237,6 @@ class FeedbackFHN(NeuGroup):
y_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms].



References
----------
.. [4] Plant, Richard E. (1981). *A FitzHugh Differential-Difference
@@ -234,52 +248,58 @@ class FeedbackFHN(NeuGroup):
def __init__(
self,
size: Shape,
keep_size: bool = False,

# model parameters
a: Union[float, Tensor, Initializer, Callable] = 0.7,
b: Union[float, Tensor, Initializer, Callable] = 0.8,
delay: Union[float, Tensor, Initializer, Callable] = 10.,
tau: Union[float, Tensor, Initializer, Callable] = 12.5,
mu: Union[float, Tensor, Initializer, Callable] = 1.6886,
v0: Union[float, Tensor, Initializer, Callable] = -1,
a: Union[float, Array, Initializer, Callable] = 0.7,
b: Union[float, Array, Initializer, Callable] = 0.8,
delay: Union[float, Array, Initializer, Callable] = 10.,
tau: Union[float, Array, Initializer, Callable] = 12.5,
mu: Union[float, Array, Initializer, Callable] = 1.6886,
v0: Union[float, Array, Initializer, Callable] = -1,

# noise parameters
x_ou_mean: Union[float, Tensor, Initializer, Callable] = 0.0,
x_ou_sigma: Union[float, Tensor, Initializer, Callable] = 0.0,
x_ou_tau: Union[float, Tensor, Initializer, Callable] = 5.0,
y_ou_mean: Union[float, Tensor, Initializer, Callable] = 0.0,
y_ou_sigma: Union[float, Tensor, Initializer, Callable] = 0.0,
y_ou_tau: Union[float, Tensor, Initializer, Callable] = 5.0,
x_ou_mean: Union[float, Array, Initializer, Callable] = 0.0,
x_ou_sigma: Union[float, Array, Initializer, Callable] = 0.0,
x_ou_tau: Union[float, Array, Initializer, Callable] = 5.0,
y_ou_mean: Union[float, Array, Initializer, Callable] = 0.0,
y_ou_sigma: Union[float, Array, Initializer, Callable] = 0.0,
y_ou_tau: Union[float, Array, Initializer, Callable] = 5.0,

# other parameters
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
method: str = 'rk4',
sde_method: str = None,
x_initializer: Union[Initializer, Callable, Array] = Uniform(0, 0.05),
y_initializer: Union[Initializer, Callable, Array] = Uniform(0, 0.05),
method: str = 'exp_auto',
name: str = None,
dt: float = None
dt: float = None,

# parameter for training
mode: Mode = normal,
):
super(FeedbackFHN, self).__init__(size=size, name=name)
super(FeedbackFHN, self).__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)

# dt
self.dt = bm.get_dt() if dt is None else dt
check_float(self.dt, 'dt', allow_none=False, min_bound=0., allow_int=False)

# parameters
self.a = init_param(a, self.num, allow_none=False)
self.b = init_param(b, self.num, allow_none=False)
self.delay = init_param(delay, self.num, allow_none=False)
self.tau = init_param(tau, self.num, allow_none=False)
self.mu = init_param(mu, self.num, allow_none=False) # feedback strength
self.v0 = init_param(v0, self.num, allow_none=False) # resting potential
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.delay = parameter(delay, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.mu = parameter(mu, self.varshape, allow_none=False) # feedback strength
self.v0 = parameter(v0, self.varshape, allow_none=False) # resting potential

# noise parameters
self.x_ou_mean = init_param(x_ou_mean, self.num, allow_none=False)
self.y_ou_mean = init_param(y_ou_mean, self.num, allow_none=False)
self.x_ou_sigma = init_param(x_ou_sigma, self.num, allow_none=False)
self.y_ou_sigma = init_param(y_ou_sigma, self.num, allow_none=False)
self.x_ou_tau = init_param(x_ou_tau, self.num, allow_none=False)
self.y_ou_tau = init_param(y_ou_tau, self.num, allow_none=False)
self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False)
self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False)
self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False)
self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False)
self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False)
self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False)

# initializers
check_initializer(x_initializer, 'x_initializer')
@@ -288,36 +308,42 @@ class FeedbackFHN(NeuGroup):
self._y_initializer = y_initializer

# variables
self.x = bm.Variable(init_param(x_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.x = variable(x_initializer, mode, self.varshape)
self.y = variable(y_initializer, mode, self.varshape)
self.x_delay = bm.TimeDelay(self.x, self.delay, dt=self.dt, interp_method='round')
self.input = bm.Variable(bm.zeros(self.num))
self.input = variable(bm.zeros, mode, self.varshape)
self.input_y = variable(bm.zeros, mode, self.varshape)

# noise variables
self.x_ou = self.y_ou = None
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.):
self.x_ou = OUProcess(self.num,
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau,
method=sde_method)
self.x_ou = OUProcess(self.varshape,
self.x_ou_mean,
self.x_ou_sigma,
self.x_ou_tau,
method=method)
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.):
self.y_ou = OUProcess(self.num,
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau,
method=sde_method)
self.y_ou = OUProcess(self.varshape,
self.y_ou_mean,
self.y_ou_sigma,
self.y_ou_tau,
method=method)

# integral
self.integral = ddeint(method=method,
self.integral = odeint(method=method,
f=JointEq([self.dx, self.dy]),
state_delays={'V': self.x_delay})

def reset(self):
self.x.value = init_param(self._x_initializer, (self.num,))
self.y.value = init_param(self._y_initializer, (self.num,))
def reset_state(self, batch_size=None):
self.x.value = variable(self._x_initializer, batch_size, self.varshape)
self.y.value = variable(self._y_initializer, batch_size, self.varshape)
self.x_delay.reset(self.x, self.delay)
self.input[:] = 0
self.input = variable(bm.zeros, batch_size, self.varshape)
self.input_y = variable(bm.zeros, batch_size, self.varshape)
if self.x_ou is not None:
self.x_ou.reset()
self.x_ou.reset_state(batch_size)
if self.y_ou is not None:
self.y_ou.reset()
self.y_ou.reset_state(batch_size)

def dx(self, x, t, y, x_ext):
return x - x * x * x / 3 - y + x_ext + self.mu * (self.x_delay(t - self.delay) - self.v0)
@@ -325,29 +351,35 @@ class FeedbackFHN(NeuGroup):
def dy(self, y, t, x, y_ext):
return (x + self.a - self.b * y + y_ext) / self.tau

def _check_dt(self, dt, *args):
if np.absolute(dt - self.dt) > 1e-6:
raise ValueError(f'The "dt" {dt} used in model running is '
f'not consistent with the "dt" {self.dt} '
f'used in model definition.')
def _check_dt(self, dt):
raise ValueError(f'The "dt" {dt} used in model running is '
f'not consistent with the "dt" {self.dt} '
f'used in model definition.')

def update(self, t, dt):
def update(self, tdi, x=None):
t = tdi['t']
dt = tdi['dt']
if check.is_checking():
id_tap(self._check_dt, dt)
check_error_in_jit(not bm.isclose(dt, self.dt), self._check_dt, dt)

if x is not None: self.input += x
if self.x_ou is not None:
self.input += self.x_ou.x
self.x_ou.update(t, dt)
y_ext = 0.
self.x_ou.update(tdi)
if self.y_ou is not None:
y_ext = self.y_ou.x
self.y_ou.update(t, dt)
x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=y_ext, dt=dt)
self.input_y += self.y_ou.x
self.y_ou.update(tdi)

x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt)
self.x.value = x
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
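
Because the feedback term samples the delayed state from a `TimeDelay` buffer built with the model's own `dt`, the runner must step with that same `dt`; the `check_error_in_jit(...)` guard above raises through `_check_dt` when the two disagree. A usage sketch (assuming the runner is given the matching `dt`, which its constructor below accepts):

>>> fhn = bp.dyn.rates.FeedbackFHN(1, delay=10., dt=0.1)
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['x'], dt=0.1)  # must equal the model's dt
>>> runner.run(100.)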


class QIF(NeuGroup):
class QIF(RateModel):
r"""A mean-field model of a quadratic integrate-and-fire neuron population.

**Model Descriptions**
@@ -416,46 +448,52 @@ class QIF(NeuGroup):
def __init__(
self,
size: Shape,
keep_size: bool = False,

# model parameters
tau: Union[float, Tensor, Initializer, Callable] = 1.,
eta: Union[float, Tensor, Initializer, Callable] = -5.0,
delta: Union[float, Tensor, Initializer, Callable] = 1.0,
J: Union[float, Tensor, Initializer, Callable] = 15.,
tau: Union[float, Array, Initializer, Callable] = 1.,
eta: Union[float, Array, Initializer, Callable] = -5.0,
delta: Union[float, Array, Initializer, Callable] = 1.0,
J: Union[float, Array, Initializer, Callable] = 15.,

# noise parameters
x_ou_mean: Union[float, Tensor, Initializer, Callable] = 0.0,
x_ou_sigma: Union[float, Tensor, Initializer, Callable] = 0.0,
x_ou_tau: Union[float, Tensor, Initializer, Callable] = 5.0,
y_ou_mean: Union[float, Tensor, Initializer, Callable] = 0.0,
y_ou_sigma: Union[float, Tensor, Initializer, Callable] = 0.0,
y_ou_tau: Union[float, Tensor, Initializer, Callable] = 5.0,
x_ou_mean: Union[float, Array, Initializer, Callable] = 0.0,
x_ou_sigma: Union[float, Array, Initializer, Callable] = 0.0,
x_ou_tau: Union[float, Array, Initializer, Callable] = 5.0,
y_ou_mean: Union[float, Array, Initializer, Callable] = 0.0,
y_ou_sigma: Union[float, Array, Initializer, Callable] = 0.0,
y_ou_tau: Union[float, Array, Initializer, Callable] = 5.0,

# other parameters
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
x_initializer: Union[Initializer, Callable, Array] = Uniform(0, 0.05),
y_initializer: Union[Initializer, Callable, Array] = Uniform(0, 0.05),
method: str = 'exp_auto',
name: str = None,
sde_method: str = None,

# parameter for training
mode: Mode = normal,
):
super(QIF, self).__init__(size=size, name=name)
super(QIF, self).__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)

# parameters
self.tau = init_param(tau, self.num, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
# the mean of a Lorentzian distribution over the neural excitability in the population
self.eta = init_param(eta, self.num, allow_none=False)
self.eta = parameter(eta, self.varshape, allow_none=False)
# the half-width at half maximum of the Lorentzian distribution over the neural excitability
self.delta = init_param(delta, self.num, allow_none=False)
self.delta = parameter(delta, self.varshape, allow_none=False)
# the strength of the recurrent coupling inside the population
self.J = init_param(J, self.num, allow_none=False)
self.J = parameter(J, self.varshape, allow_none=False)

# noise parameters
self.x_ou_mean = init_param(x_ou_mean, self.num, allow_none=False)
self.y_ou_mean = init_param(y_ou_mean, self.num, allow_none=False)
self.x_ou_sigma = init_param(x_ou_sigma, self.num, allow_none=False)
self.y_ou_sigma = init_param(y_ou_sigma, self.num, allow_none=False)
self.x_ou_tau = init_param(x_ou_tau, self.num, allow_none=False)
self.y_ou_tau = init_param(y_ou_tau, self.num, allow_none=False)
self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False)
self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False)
self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False)
self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False)
self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False)
self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False)

# initializers
check_initializer(x_initializer, 'x_initializer')
@@ -464,32 +502,38 @@ class QIF(NeuGroup):
self._y_initializer = y_initializer

# variables
self.x = bm.Variable(init_param(x_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.x = variable(x_initializer, mode, self.varshape)
self.y = variable(y_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.input_y = variable(bm.zeros, mode, self.varshape)

# noise variables
self.x_ou = self.y_ou = None
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.):
self.x_ou = OUProcess(self.num,
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau,
method=sde_method)
self.x_ou = OUProcess(self.varshape,
self.x_ou_mean,
self.x_ou_sigma,
self.x_ou_tau,
method=method)
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.):
self.y_ou = OUProcess(self.num,
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau,
method=sde_method)
self.y_ou = OUProcess(self.varshape,
self.y_ou_mean,
self.y_ou_sigma,
self.y_ou_tau,
method=method)

# functions
self.integral = odeint(JointEq([self.dx, self.dy]), method=method)

def reset(self):
self.x.value = init_param(self._x_initializer, (self.num,))
self.y.value = init_param(self._y_initializer, (self.num,))
self.input[:] = 0
def reset_state(self, batch_size=None):
self.x.value = variable(self._x_initializer, batch_size, self.varshape)
self.y.value = variable(self._y_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.input_y.value = variable(bm.zeros, batch_size, self.varshape)
if self.x_ou is not None:
self.x_ou.reset()
self.x_ou.reset_state(batch_size)
if self.y_ou is not None:
self.y_ou.reset()
self.y_ou.reset_state(batch_size)

def dy(self, y, t, x, y_ext):
return (self.delta / (bm.pi * self.tau) + 2. * x * y + y_ext) / self.tau
@@ -498,21 +542,27 @@ class QIF(NeuGroup):
return (x ** 2 + self.eta + x_ext + self.J * y * self.tau -
(bm.pi * y * self.tau) ** 2) / self.tau

def update(self, t, dt):
def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']

if x is not None: self.input += x
if self.x_ou is not None:
self.input += self.x_ou.x
self.x_ou.update(t, dt)
y_ext = 0.
self.x_ou.update(tdi)
if self.y_ou is not None:
y_ext = self.y_ou.x
self.y_ou.update(t, dt)
x, y = self.integral(self.x, self.y, t=t, x_ext=self.input, y_ext=y_ext, dt=dt)
self.input_y += self.y_ou.x
self.y_ou.update(tdi)

x, y = self.integral(self.x, self.y, t=t, x_ext=self.input, y_ext=self.input_y, dt=dt)
self.x.value = x
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
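
Written out, the coupled mean-field equations integrated by `dx` and `dy` above are (in the usual reading of this model, `x` is the mean membrane potential and `y` the population firing rate):

.. math::

   \tau \frac{dx}{dt} = x^2 + \eta + I_x(t) + J y \tau - (\pi y \tau)^2, \qquad
   \tau \frac{dy}{dt} = \frac{\Delta}{\pi \tau} + 2 x y + I_y(t)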


class StuartLandauOscillator(Population):
class StuartLandauOscillator(RateModel):
r"""
Stuart-Landau model with Hopf bifurcation.

@@ -541,40 +591,45 @@ class StuartLandauOscillator(Population):
def __init__(
self,
size: Shape,
keep_size: bool = False,

# model parameters
a: Union[float, Tensor, Initializer, Callable] = 0.25,
w: Union[float, Tensor, Initializer, Callable] = 0.2,
a: Union[float, Array, Initializer, Callable] = 0.25,
w: Union[float, Array, Initializer, Callable] = 0.2,

# noise parameters
x_ou_mean: Union[float, Tensor, Initializer, Callable] = 0.0,
x_ou_sigma: Union[float, Tensor, Initializer, Callable] = 0.0,
x_ou_tau: Union[float, Tensor, Initializer, Callable] = 5.0,
y_ou_mean: Union[float, Tensor, Initializer, Callable] = 0.0,
y_ou_sigma: Union[float, Tensor, Initializer, Callable] = 0.0,
y_ou_tau: Union[float, Tensor, Initializer, Callable] = 5.0,
x_ou_mean: Union[float, Array, Initializer, Callable] = 0.0,
x_ou_sigma: Union[float, Array, Initializer, Callable] = 0.0,
x_ou_tau: Union[float, Array, Initializer, Callable] = 5.0,
y_ou_mean: Union[float, Array, Initializer, Callable] = 0.0,
y_ou_sigma: Union[float, Array, Initializer, Callable] = 0.0,
y_ou_tau: Union[float, Array, Initializer, Callable] = 5.0,

# other parameters
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.5),
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.5),
x_initializer: Union[Initializer, Callable, Array] = Uniform(0, 0.5),
y_initializer: Union[Initializer, Callable, Array] = Uniform(0, 0.5),
method: str = 'exp_auto',
sde_method: str = None,
name: str = None,

# parameter for training
mode: Mode = normal,
):
super(StuartLandauOscillator, self).__init__(size=size,
name=name)
name=name,
keep_size=keep_size,
mode=mode)

# model parameters
self.a = init_param(a, self.num, allow_none=False)
self.w = init_param(w, self.num, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.w = parameter(w, self.varshape, allow_none=False)

# noise parameters
self.x_ou_mean = init_param(x_ou_mean, self.num, allow_none=False)
self.y_ou_mean = init_param(y_ou_mean, self.num, allow_none=False)
self.x_ou_sigma = init_param(x_ou_sigma, self.num, allow_none=False)
self.y_ou_sigma = init_param(y_ou_sigma, self.num, allow_none=False)
self.x_ou_tau = init_param(x_ou_tau, self.num, allow_none=False)
self.y_ou_tau = init_param(y_ou_tau, self.num, allow_none=False)
self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False)
self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False)
self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False)
self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False)
self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False)
self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False)

# initializers
check_initializer(x_initializer, 'x_initializer')
@@ -583,32 +638,38 @@ class StuartLandauOscillator(Population):
self._y_initializer = y_initializer

# variables
self.x = bm.Variable(init_param(x_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.x = variable(x_initializer, mode, self.varshape)
self.y = variable(y_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.input_y = variable(bm.zeros, mode, self.varshape)

# noise variables
self.x_ou = self.y_ou = None
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.):
self.x_ou = OUProcess(self.num,
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau,
method=sde_method)
self.x_ou = OUProcess(self.varshape,
self.x_ou_mean,
self.x_ou_sigma,
self.x_ou_tau,
method=method)
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.):
self.y_ou = OUProcess(self.num,
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau,
method=sde_method)
self.y_ou = OUProcess(self.varshape,
self.y_ou_mean,
self.y_ou_sigma,
self.y_ou_tau,
method=method)

# integral functions
self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method)

def reset(self):
self.x.value = init_param(self._x_initializer, (self.num,))
self.y.value = init_param(self._y_initializer, (self.num,))
self.input[:] = 0
def reset_state(self, batch_size=None):
self.x.value = variable(self._x_initializer, batch_size, self.varshape)
self.y.value = variable(self._y_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.input_y.value = variable(bm.zeros, batch_size, self.varshape)
if self.x_ou is not None:
self.x_ou.reset()
self.x_ou.reset_state(batch_size)
if self.y_ou is not None:
self.y_ou.reset()
self.y_ou.reset_state(batch_size)

def dx(self, x, t, y, x_ext, a, w):
return (a - x * x - y * y) * x - w * y + x_ext
@@ -616,22 +677,34 @@ class StuartLandauOscillator(Population):
def dy(self, y, t, x, y_ext, a, w):
return (a - x * x - y * y) * y - w * y + y_ext

def update(self, t, dt):
def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']

if x is not None: self.input += x
if self.x_ou is not None:
self.input += self.x_ou.x
self.x_ou.update(t, dt)
y_ext = 0.
self.x_ou.update(tdi)
if self.y_ou is not None:
y_ext = self.y_ou.x
self.y_ou.update(t, dt)
x, y = self.integral(self.x, self.y, t, x_ext=self.input,
y_ext=y_ext, a=self.a, w=self.w, dt=dt)
self.input_y += self.y_ou.x
self.y_ou.update(tdi)

x, y = self.integral(self.x,
self.y,
t=t,
x_ext=self.input,
y_ext=self.input_y,
a=self.a,
w=self.w,
dt=dt)
self.x.value = x
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
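
A usage sketch in the same style as the other populations (assuming the oscillator is exported under `bp.dyn.rates` like the classes above):

>>> import brainpy as bp
>>> osc = bp.dyn.rates.StuartLandauOscillator(1, a=0.25, w=0.2)
>>> runner = bp.dyn.DSRunner(osc, monitors=['x', 'y'])
>>> runner.run(100.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x')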


class WilsonCowanModel(Population):
class WilsonCowanModel(RateModel):
"""Wilson-Cowan population model.


@@ -656,65 +729,68 @@ class WilsonCowanModel(Population):
def __init__(
self,
size: Shape,
keep_size: bool = False,

# Excitatory parameters
E_tau: Union[float, Tensor, Initializer, Callable] = 1., # excitatory time constant
E_a: Union[float, Tensor, Initializer, Callable] = 1.2, # excitatory gain
E_theta: Union[float, Tensor, Initializer, Callable] = 2.8, # excitatory firing threshold
E_tau: Union[float, Array, Initializer, Callable] = 1., # excitatory time constant
E_a: Union[float, Array, Initializer, Callable] = 1.2, # excitatory gain
E_theta: Union[float, Array, Initializer, Callable] = 2.8, # excitatory firing threshold

# Inhibitory parameters
I_tau: Union[float, Tensor, Initializer, Callable] = 1., # inhibitory time constant
I_a: Union[float, Tensor, Initializer, Callable] = 1., # inhibitory gain
I_theta: Union[float, Tensor, Initializer, Callable] = 4.0, # inhibitory firing threshold
I_tau: Union[float, Array, Initializer, Callable] = 1., # inhibitory time constant
I_a: Union[float, Array, Initializer, Callable] = 1., # inhibitory gain
I_theta: Union[float, Array, Initializer, Callable] = 4.0, # inhibitory firing threshold

# connection parameters
wEE: Union[float, Tensor, Initializer, Callable] = 12., # local E-E coupling
wIE: Union[float, Tensor, Initializer, Callable] = 4., # local E-I coupling
wEI: Union[float, Tensor, Initializer, Callable] = 13., # local I-E coupling
wII: Union[float, Tensor, Initializer, Callable] = 11., # local I-I coupling
wEE: Union[float, Array, Initializer, Callable] = 12., # local E-E coupling
wIE: Union[float, Array, Initializer, Callable] = 4., # local E-I coupling
wEI: Union[float, Array, Initializer, Callable] = 13., # local I-E coupling
wII: Union[float, Array, Initializer, Callable] = 11., # local I-I coupling

# Refractory parameter
r: Union[float, Tensor, Initializer, Callable] = 1.,
r: Union[float, Array, Initializer, Callable] = 1.,

# noise parameters
x_ou_mean: Union[float, Tensor, Initializer, Callable] = 0.0,
x_ou_sigma: Union[float, Tensor, Initializer, Callable] = 0.0,
x_ou_tau: Union[float, Tensor, Initializer, Callable] = 5.0,
y_ou_mean: Union[float, Tensor, Initializer, Callable] = 0.0,
y_ou_sigma: Union[float, Tensor, Initializer, Callable] = 0.0,
y_ou_tau: Union[float, Tensor, Initializer, Callable] = 5.0,
x_ou_mean: Union[float, Array, Initializer, Callable] = 0.0,
x_ou_sigma: Union[float, Array, Initializer, Callable] = 0.0,
x_ou_tau: Union[float, Array, Initializer, Callable] = 5.0,
y_ou_mean: Union[float, Array, Initializer, Callable] = 0.0,
y_ou_sigma: Union[float, Array, Initializer, Callable] = 0.0,
y_ou_tau: Union[float, Array, Initializer, Callable] = 5.0,

# state initializer
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(max_val=0.05),
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(max_val=0.05),
x_initializer: Union[Initializer, Callable, Array] = Uniform(max_val=0.05),
y_initializer: Union[Initializer, Callable, Array] = Uniform(max_val=0.05),

# other parameters
sde_method: str = None,
method: str = 'exp_euler_auto',
name: str = None,

# parameter for training
mode: Mode = normal,
):
super(WilsonCowanModel, self).__init__(size=size, name=name)
super(WilsonCowanModel, self).__init__(size=size, name=name, keep_size=keep_size)

# model parameters
self.E_a = init_param(E_a, self.num, allow_none=False)
self.I_a = init_param(I_a, self.num, allow_none=False)
self.E_tau = init_param(E_tau, self.num, allow_none=False)
self.I_tau = init_param(I_tau, self.num, allow_none=False)
self.E_theta = init_param(E_theta, self.num, allow_none=False)
self.I_theta = init_param(I_theta, self.num, allow_none=False)
self.wEE = init_param(wEE, self.num, allow_none=False)
self.wIE = init_param(wIE, self.num, allow_none=False)
self.wEI = init_param(wEI, self.num, allow_none=False)
self.wII = init_param(wII, self.num, allow_none=False)
self.r = init_param(r, self.num, allow_none=False)
self.E_a = parameter(E_a, self.varshape, allow_none=False)
self.I_a = parameter(I_a, self.varshape, allow_none=False)
self.E_tau = parameter(E_tau, self.varshape, allow_none=False)
self.I_tau = parameter(I_tau, self.varshape, allow_none=False)
self.E_theta = parameter(E_theta, self.varshape, allow_none=False)
self.I_theta = parameter(I_theta, self.varshape, allow_none=False)
self.wEE = parameter(wEE, self.varshape, allow_none=False)
self.wIE = parameter(wIE, self.varshape, allow_none=False)
self.wEI = parameter(wEI, self.varshape, allow_none=False)
self.wII = parameter(wII, self.varshape, allow_none=False)
self.r = parameter(r, self.varshape, allow_none=False)

# noise parameters
self.x_ou_mean = init_param(x_ou_mean, self.num, allow_none=False)
self.y_ou_mean = init_param(y_ou_mean, self.num, allow_none=False)
self.x_ou_sigma = init_param(x_ou_sigma, self.num, allow_none=False)
self.y_ou_sigma = init_param(y_ou_sigma, self.num, allow_none=False)
self.x_ou_tau = init_param(x_ou_tau, self.num, allow_none=False)
self.y_ou_tau = init_param(y_ou_tau, self.num, allow_none=False)
self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False)
self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False)
self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False)
self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False)
self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False)
self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False)

# initializers
check_initializer(x_initializer, 'x_initializer')
@@ -723,32 +799,38 @@ class WilsonCowanModel(Population):
self._y_initializer = y_initializer

# variables
self.x = bm.Variable(init_param(x_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.x = variable(x_initializer, mode, self.varshape)
self.y = variable(y_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.input_y = variable(bm.zeros, mode, self.varshape)

# noise variables
self.x_ou = self.y_ou = None
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.):
self.x_ou = OUProcess(self.num,
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau,
method=sde_method)
self.x_ou = OUProcess(self.varshape,
self.x_ou_mean,
self.x_ou_sigma,
self.x_ou_tau,
method=method)
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.):
self.y_ou = OUProcess(self.num,
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau,
method=sde_method)
self.y_ou = OUProcess(self.varshape,
self.y_ou_mean,
self.y_ou_sigma,
self.y_ou_tau,
method=method)

# functions
self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method)

def reset(self):
self.x.value = init_param(self._x_initializer, (self.num,))
self.y.value = init_param(self._y_initializer, (self.num,))
self.input[:] = 0
def reset_state(self, batch_size=None):
self.x.value = variable(self._x_initializer, batch_size, self.varshape)
self.y.value = variable(self._y_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.input_y.value = variable(bm.zeros, batch_size, self.varshape)
if self.x_ou is not None:
self.x_ou.reset()
self.x_ou.reset_state(batch_size)
if self.y_ou is not None:
self.y_ou.reset()
self.y_ou.reset_state(batch_size)

def F(self, x, a, theta):
return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta))
@@ -761,41 +843,45 @@ class WilsonCowanModel(Population):
x = self.wEI * x - self.wII * y + y_ext
return (-y + (1 - self.r * y) * self.F(x, self.I_a, self.I_theta)) / self.I_tau

def update(self, t, dt):
def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']
if x is not None: self.input += x
if self.x_ou is not None:
self.input += self.x_ou.x
self.x_ou.update(t, dt)
y_ext = 0.
self.x_ou.update(tdi)
if self.y_ou is not None:
y_ext = self.y_ou.x
self.y_ou.update(t, dt)
x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=y_ext, dt=dt)
self.input_y += self.y_ou.x
self.y_ou.update(tdi)
x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt)
self.x.value = x
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
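
The transfer function `F` shared by the excitatory and inhibitory populations is the logistic sigmoid shifted so that F(0) = 0, exactly as coded above:

.. math::

   F(x; a, \theta) = \frac{1}{1 + e^{-a (x - \theta)}} - \frac{1}{1 + e^{a \theta}}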


class JansenRitModel(Population):
class JansenRitModel(RateModel):
pass


class KuramotoOscillator(Population):
class KuramotoOscillator(RateModel):
pass


class ThetaNeuron(Population):
class ThetaNeuron(RateModel):
pass


class RateQIFWithSFA(Population):
class RateQIFWithSFA(RateModel):
pass


class VanDerPolOscillator(Population):
class VanDerPolOscillator(RateModel):
pass


class ThresholdLinearModel(Population):
class ThresholdLinearModel(RateModel):
r"""A threshold linear rate model.

The threshold linear rate model is given by [1]_
@@ -824,57 +910,71 @@ class ThresholdLinearModel(Population):
def __init__(
self,
size: Shape,
tau_e: Union[float, Callable, Initializer, Tensor] = 2e-2,
tau_i: Union[float, Callable, Initializer, Tensor] = 1e-2,
beta_e: Union[float, Callable, Initializer, Tensor] = .066,
beta_i: Union[float, Callable, Initializer, Tensor] = .351,
noise_e: Union[float, Callable, Initializer, Tensor] = 0.,
noise_i: Union[float, Callable, Initializer, Tensor] = 0.,
e_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(),
i_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(),
tau_e: Union[float, Callable, Initializer, Array] = 2e-2,
tau_i: Union[float, Callable, Initializer, Array] = 1e-2,
beta_e: Union[float, Callable, Initializer, Array] = .066,
beta_i: Union[float, Callable, Initializer, Array] = .351,
noise_e: Union[float, Callable, Initializer, Array] = 0.,
noise_i: Union[float, Callable, Initializer, Array] = 0.,
e_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
i_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
seed: int = None,
name: str = None
keep_size: bool = False,
name: str = None,

# parameter for training
mode: Mode = normal,
):
super(ThresholdLinearModel, self).__init__(size, name=name)
super(ThresholdLinearModel, self).__init__(size,
name=name,
keep_size=keep_size,
mode=mode)

# parameters
self.seed = seed
self.tau_e = init_param(tau_e, self.num, False)
self.tau_i = init_param(tau_i, self.num, False)
self.beta_e = init_param(beta_e, self.num, False)
self.beta_i = init_param(beta_i, self.num, False)
self.noise_e = init_param(noise_e, self.num, False)
self.noise_i = init_param(noise_i, self.num, False)
self.tau_e = parameter(tau_e, self.varshape, False)
self.tau_i = parameter(tau_i, self.varshape, False)
self.beta_e = parameter(beta_e, self.varshape, False)
self.beta_i = parameter(beta_i, self.varshape, False)
self.noise_e = parameter(noise_e, self.varshape, False)
self.noise_i = parameter(noise_i, self.varshape, False)
self._e_initializer = e_initializer
self._i_initializer = i_initializer

# variables
self.e = bm.Variable(init_param(e_initializer, self.num)) # Firing rate of excitatory population
self.i = bm.Variable(init_param(i_initializer, self.num)) # Firing rate of inhibitory population
self.Ie = bm.Variable(bm.zeros(self.num)) # Input of excitatory population
self.Ii = bm.Variable(bm.zeros(self.num)) # Input of inhibitory population
self.e = variable(e_initializer, mode, self.varshape) # Firing rate of excitatory population
self.i = variable(i_initializer, mode, self.varshape) # Firing rate of inhibitory population
self.Ie = variable(bm.zeros, mode, self.varshape) # Input of excitatory population
self.Ii = variable(bm.zeros, mode, self.varshape) # Input of inhibitory population
if bm.any(self.noise_e != 0) or bm.any(self.noise_i != 0):
self.rng = bm.random.RandomState(self.seed)

def reset(self):
def reset(self, batch_size=None):
self.rng.seed(self.seed)
self.e.value = init_param(self._e_initializer, self.num)
self.i.value = init_param(self._i_initializer, self.num)
self.Ie[:] = 0.
self.Ii[:] = 0.
self.reset_state(batch_size)

def reset_state(self, batch_size=None):
self.e.value = variable(self._e_initializer, batch_size, self.varshape)
self.i.value = variable(self._i_initializer, batch_size, self.varshape)
self.Ie.value = variable(bm.zeros, batch_size, self.varshape)
self.Ii.value = variable(bm.zeros, batch_size, self.varshape)

def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']

def update(self, t, dt):
if x is not None: self.Ie += x
de = -self.e + self.beta_e * bm.maximum(self.Ie, 0.)
if bm.any(self.noise_e != 0.):
de += self.rng.randn(self.num) * self.noise_e
de += self.rng.randn(self.varshape) * self.noise_e
de = de / self.tau_e
self.e.value = bm.maximum(self.e + de * dt, 0.)

di = -self.i + self.beta_i * bm.maximum(self.Ii, 0.)
if bm.any(self.noise_i != 0.):
di += self.rng.randn(self.num) * self.noise_i
di += self.rng.randn(self.varshape) * self.noise_i
di = di / self.tau_i
self.i.value = bm.maximum(self.i + di * dt, 0.)

def clear_input(self):
self.Ie[:] = 0.
self.Ii[:] = 0.
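
The `update` above advances the threshold-linear dynamics with a plain Euler step and clips both rates at zero afterwards; in continuous form it corresponds to

.. math::

   \tau_e \frac{de}{dt} = -e + \beta_e \,[I_e]_+ + \xi_e, \qquad
   \tau_i \frac{di}{dt} = -i + \beta_i \,[I_i]_+ + \xi_i,

where :math:`[\cdot]_+` denotes rectification and :math:`\xi_{e/i}` is the optional Gaussian noise scaled by ``noise_e`` / ``noise_i``; after each Euler step `e` and `i` are additionally constrained to stay non-negative.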



+548  -228   brainpy/dyn/runners.py

@@ -1,23 +1,242 @@
# -*- coding: utf-8 -*-

import time
from collections.abc import Iterable
from typing import Dict, Union, Sequence, Callable

import jax
import jax.numpy as jnp
import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_map, tree_flatten

from brainpy.base.base import TensorCollector
from brainpy import math as bm
from brainpy.dyn import utils
from brainpy.dyn.base import DynamicalSystem
from brainpy.errors import RunningError
from brainpy.running.runner import Runner
from brainpy.tools.checking import check_float, serialize_kwargs
from brainpy.tools.others.dicts import DotDict
from brainpy.types import Array, Output, Monitor

__all__ = [
'DSRunner', 'ReportRunner', 'StructRunner',
'DSRunner',
]

SUPPORTED_INPUT_OPS = ['-', '+', '*', '/', '=']
SUPPORTED_INPUT_TYPE = ['fix', 'iter', 'func']


def check_and_format_inputs(host, inputs):
"""Check inputs and get the formatted inputs for the given population.

Parameters
----------
host : DynamicalSystem
The host which contains all data.
inputs : tuple, list
The inputs of the population.

Returns
-------
formatted_inputs : tuple, list
The formatted inputs of the population.
"""

# 1. check inputs
# ---------
if inputs is None:
inputs = []
if not isinstance(inputs, (tuple, list)):
raise RunningError('"inputs" must be a tuple/list.')
if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)):
if isinstance(inputs[0], (str, bm.Variable)):
inputs = [inputs]
else:
raise RunningError('Unknown input structure, only support inputs '
'with format of "(target, value, [type, operation])".')
for one_input in inputs:
if not 2 <= len(one_input) <= 4:
raise RunningError('For each target, you must specify '
'"(target, value, [type, operation])".')
if len(one_input) == 3 and one_input[2] not in SUPPORTED_INPUT_TYPE:
raise RunningError(f'Input type only supports '
f'"{SUPPORTED_INPUT_TYPE}", '
f'not "{one_input[2]}".')
if len(one_input) == 4 and one_input[3] not in SUPPORTED_INPUT_OPS:
raise RunningError(f'Input operation only supports '
f'"{SUPPORTED_INPUT_OPS}", '
f'not "{one_input[3]}".')

# 2. get targets and attributes
# ---------
inputs_which_found_target = []
inputs_not_found_target = []

# checking 1: absolute access
# Check whether the input target node is accessible,
# and check whether the target node has the attribute
nodes = None
for one_input in inputs:
key = one_input[0]
if isinstance(key, bm.Variable):
real_target = key
elif isinstance(key, str):
if nodes is None:
nodes = host.nodes(method='absolute', level=-1, include_self=True)
splits = key.split('.')
target = '.'.join(splits[:-1])
key = splits[-1]
if target == '':
real_target = host
else:
if target not in nodes:
inputs_not_found_target.append(one_input)
continue
real_target = nodes[target]
if not hasattr(real_target, key):
raise RunningError(f'Input target key "{key}" is not defined in {real_target}.')
real_target = getattr(real_target, key)
else:
raise RunningError(f'For each input, input[0] must be a string to '
f'specify variable of the target, but we got {key}.')
inputs_which_found_target.append((real_target,) + tuple(one_input[1:]))

# checking 2: relative access
# Check whether the input target node is accessible
# and check whether the target node has the attribute
if len(inputs_not_found_target):
nodes = host.nodes(method='relative', level=-1, include_self=True)
for one_input in inputs_not_found_target:
splits = one_input[0].split('.')
target, key = '.'.join(splits[:-1]), splits[-1]
if target not in nodes:
raise RunningError(f'Input target "{target}" is not defined in {host}.')
real_target = nodes[target]
if not hasattr(real_target, key):
raise RunningError(f'Input target key "{key}" is not defined in {real_target}.')
real_target = getattr(real_target, key)
inputs_which_found_target.append((real_target,) + tuple(one_input[1:]))

# 3. format inputs
# ---------
formatted_inputs = []
for one_input in inputs_which_found_target:
# input value
data_value = one_input[1]

# input type
if len(one_input) >= 3:
if one_input[2] == 'iter':
if not isinstance(data_value, Iterable):
raise ValueError(f'Input "{data_value}" for "{one_input[0]}" \n'
f'is set to be "iter" type, however we got the value with '
f'the type of {type(data_value)}')
elif one_input[2] == 'func':
if not callable(data_value):
raise ValueError(f'Input "{data_value}" for "{one_input[0]}" \n'
f'is set to be "func" type, however we got the value with '
f'the type of {type(data_value)}')
elif one_input[2] != 'fix':
raise RunningError(f'Only support {SUPPORTED_INPUT_TYPE} input type, but '
f'we got "{one_input[2]}"')

data_type = one_input[2]
else:
data_type = 'fix'

# operation
if len(one_input) == 4:
data_op = one_input[3]
else:
data_op = '+'
if data_op not in SUPPORTED_INPUT_OPS:
raise RunningError(f'Only support {SUPPORTED_INPUT_OPS}, while we got '
f'{data_op} in {one_input}')

# final
format_inp = (one_input[0], data_value, data_type, data_op)
formatted_inputs.append(format_inp)

return formatted_inputs
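
Concretely, each entry handed to this function is a `(target, value[, type[, operation]])` tuple: the target may be an absolute or relative attribute path resolved through `host.nodes(...)`, or a `bm.Variable` directly; missing fields default to `type='fix'` and `operation='+'`. A sketch of valid specifications (the network `net`, its group `E`, and the array `input_currents` are hypothetical placeholders):

# Hypothetical input specifications; DSRunner passes them to check_and_format_inputs().
inputs = [
  ('E.input', 5.),                             # fixed value, added ('+') every step
  ('E.V', -65., 'fix', '='),                   # overwrite E.V at every step
  ('E.input', input_currents, 'iter'),         # an iterable/array indexed by the step counter
  ('E.input', lambda t, dt: 5. * t, 'func'),   # a function of (t, dt) evaluated each step
]
runner = bp.dyn.DSRunner(net, inputs=inputs)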


def build_inputs(inputs, fun_inputs):
"""Build input function.

Parameters
----------
inputs : tuple, list
The inputs of the population.
fun_inputs: optional, callable
The input function customized by users.

Returns
-------
func: callable
The input function.
"""

fix_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}
next_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}
func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}
array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}

if not (fun_inputs is None or callable(fun_inputs)):
raise ValueError

_has_iter_array = False
for variable, value, type_, op in inputs:
# variable
if not isinstance(variable, bm.Variable):
raise RunningError(f'{variable}\n is not a dynamically changed Variable; '
f'its value will not change during running, so there is '
f'no need to give it an input.')

# input data
if type_ == 'iter':
if isinstance(value, (bm.ndarray, np.ndarray, jnp.ndarray)):
array_inputs[op].append([variable, bm.asarray(value)])
_has_iter_array = True
else:
next_inputs[op].append([variable, iter(value)])
elif type_ == 'func':
func_inputs[op].append([variable, value])
else:
fix_inputs[op].append([variable, value])

def _f_ops(ops, var, data):
if ops == '=':
var[:] = data
elif ops == '+':
var += data
elif ops == '-':
var -= data
elif ops == '*':
var *= data
elif ops == '/':
var /= data
else:
raise ValueError(f'Unknown input operation: {ops}')

def func(tdi):
if fun_inputs is not None:
fun_inputs(tdi)
for ops, values in fix_inputs.items():
for var, data in values:
_f_ops(ops, var, data)
for ops, values in array_inputs.items():
for var, data in values:
_f_ops(ops, var, data[tdi['i']])
for ops, values in func_inputs.items():
for var, data in values:
_f_ops(ops, var, data(tdi['t'], tdi['dt']))
for ops, values in next_inputs.items():
for var, data in values:
_f_ops(ops, var, next(data))

return func, _has_iter_array
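
# Hedged usage sketch: build_inputs turns the formatted specifications into a single
# step function that is called once per time step with the shared-argument dict
# (time `t`, step index `i`, step size `dt`); the second return value reports whether
# an array-typed 'iter' input was given (names below are hypothetical):
#
#   input_step, has_iter_array = build_inputs(formatted, fun_inputs=None)
#   input_step({'t': 0.1, 'i': 1, 'dt': 0.1})   # applies '=', '+', '-', '*', '/' in place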


class DSRunner(Runner):
"""The runner for dynamical systems.
@@ -26,6 +245,7 @@ class DSRunner(Runner):
----------
target : DynamicalSystem
The target model to run.

inputs : list, tuple
The inputs for the target DynamicalSystem. It should be in the format
of `[(target, value, [type, operation])]`, where `target` is the
@@ -40,267 +260,367 @@ class DSRunner(Runner):
- ``operation``: should be a string, support `+`, `-`, `*`, `/`, `=`.
- Also, if you want to specify multiple inputs, just give multiple ``(target, value, [type, operation])``,
for example ``[(target1, value1), (target2, value2)]``.

fun_inputs: callable
The functional inputs. Manually specify the inputs for the target variables.
This input function should receive one argument `shared` which contains the shared arguments like
time `t`, time step `dt`, and index `i`.

monitors: None, sequence of str, dict, Monitor
Variables to monitor.

- A list of string. Like `monitors=['a', 'b', 'c']`
- A list of string with index specification. Like `monitors=[('a', 1), ('b', [1,3,5]), 'c']`
- A dict with the explicit monitor target, like: `monitors={'a': model.spike, 'b': model.V}`
- A dict with the index specification, like: `monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}`

fun_monitors: dict
Monitoring variables by callable functions. Should be a dict.
The `key` should be a string, which is used later to retrieve the result via `runner.mon[key]`.
The `value` should be a callable function which receives two arguments: `t` and `dt`.

jit: bool, dict
The JIT settings.

progress_bar: bool
Whether to use a progress bar to report the running progress.

dyn_vars: Optional, dict
The dynamically changed variables. Instance of :py:class:`~.Variable`.

numpy_mon_after_run : bool
Whether to transform the JAX arrays into NumPy ndarrays after the network finishes running.

"""

def __init__(self, target: DynamicalSystem, inputs=(), dt=None, **kwargs):
target: DynamicalSystem

def __init__(
self,
target: DynamicalSystem,

# inputs for target variables
inputs: Sequence = (),
fun_inputs: Callable = None,

# extra info
dt: float = None,
t0: Union[float, int] = 0.,
**kwargs
):
if not isinstance(target, DynamicalSystem):
raise RunningError(f'"target" must be an instance of {DynamicalSystem.__name__}, '
f'but we got {type(target)}: {target}')
super(DSRunner, self).__init__(target=target, **kwargs)

# t0 and i0
self._t0 = t0
self.i0 = 0
self.t0 = check_float(t0, 't0', allow_none=False, allow_int=True)

# parameters
dt = bm.get_dt() if dt is None else dt
if not isinstance(dt, (int, float)):
raise RunningError(f'"dt" must be scalar, but got {dt}')
self.dt = dt
if not isinstance(target, DynamicalSystem):
raise RunningError(f'"target" must be an instance of {DynamicalSystem.__name__}, '
f'but we got {type(target)}: {target}')

# Build the monitor function
self._monitor_step = self.build_monitors()

# whether has iterable input data
self._has_iter_array = False # default do not have iterable input array
self._i = bm.Variable(bm.asarray([0]))
self._mon_info = self.format_monitors()

# Build input function
inputs = utils.check_and_format_inputs(host=target, inputs=inputs)
self._input_step = self.build_inputs(inputs)

# start simulation time
self._start_t = None

# JAX does not support iterator in fori_loop, scan, etc.
# https://github.com/google/jax/issues/3567
# We use Variable i to index the current input data.
if self._has_iter_array: # must behind of "self.build_input()"
self.dyn_vars.update({'_i': self._i})
else:
self._i = None
self._input_step, _ = build_inputs(check_and_format_inputs(host=target, inputs=inputs),
fun_inputs=fun_inputs)

# run function
self._run_func = self.build_run_function()

def build_inputs(self, inputs):
fix_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}
next_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}
func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}
array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}

for target, key, value, type_, op in inputs:
# variable
variable = getattr(target, key)
if not isinstance(variable, bm.Variable):
raise RunningError(f'"{key}" in {target} is not a dynamically changed Variable, '
f'its value will not change, we think there is no need to '
f'give its input.')

# input data
if type_ == 'iter':
if isinstance(value, (bm.ndarray, np.ndarray, jnp.ndarray)):
array_inputs[op].append([variable, bm.asarray(value)])
self._has_iter_array = True
else:
next_inputs[op].append([variable, iter(value)])
elif type_ == 'func':
func_inputs[op].append([variable, value])
else:
fix_inputs[op].append([variable, value])

def _f_ops(ops, var, data):
if ops == '=':
var[:] = data
elif ops == '+':
var += data
elif ops == '-':
var -= data
elif ops == '*':
var *= data
elif ops == '/':
var /= data
else:
raise ValueError(f'Unknown input operation: {ops}')

def func(_t, _dt):
for ops, values in fix_inputs.items():
for var, data in values:
_f_ops(ops, var, data)
for ops, values in array_inputs.items():
for var, data in values:
_f_ops(ops, var, data[self._i[0]])
for ops, values in func_inputs.items():
for var, data in values:
_f_ops(ops, var, data(_t, _dt))
for ops, values in next_inputs.items():
for var, data in values:
_f_ops(ops, var, next(data))
if self._has_iter_array:
self._i += 1

return func
self._f_predict_compiled = dict()

def build_monitors(self):
monitors = utils.check_and_format_monitors(host=self.target, mon=self.mon)

return_with_idx = dict()
return_without_idx = dict()
for key, target, variable, idx, interval in monitors:
if interval is not None:
raise ValueError(f'Running with "{self.__class__.__name__}" does '
f'not support "interval" in the monitor.')
data = target
for k in variable.split('.'):
data = getattr(data, k)
if not isinstance(data, bm.Variable):
raise RunningError(f'"{key}" in {target} is not a dynamically changed Variable, '
f'its value will not change, we think there is no need to '
f'monitor its trajectory.')
if idx is None:
return_without_idx[key] = data
else:
return_with_idx[key] = (data, bm.asarray(idx))

def func(_t, _dt):
res = {k: (v.flatten() if bm.ndim(v) > 1 else v.value)
for k, v in return_without_idx.items()}
res.update({k: (v.flatten()[idx] if bm.ndim(v) > 1 else v[idx])
for k, (v, idx) in return_with_idx.items()})
res.update({k: f(_t, _dt) for k, f in self.fun_monitors.items()})
def build_monitors(self, return_without_idx, return_with_idx, shared_args: dict):
def func(tdi):
res = {k: v.value for k, v in return_without_idx.items()}
res.update({k: v[idx] for k, (v, idx) in return_with_idx.items()})
res.update({k: f(tdi) for k, f in self.fun_monitors.items()})
return res

return func

def _run_one_step(self, _t):
self._input_step(_t, self.dt)
self.target.update(_t, self.dt)
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())
return self._monitor_step(_t, self.dt)

def build_run_function(self):
if self.jit:
dyn_vars = TensorCollector()
dyn_vars.update(self.dyn_vars)
dyn_vars.update(self.target.vars().unique())
f_run = bm.make_loop(self._run_one_step,
dyn_vars=dyn_vars,
has_return=True)
else:
def f_run(all_t):
for i in range(all_t.shape[0]):
mon = self._run_one_step(all_t[i])
for k, v in mon.items():
self.mon.item_contents[k].append(v)
return None, {}
return f_run

def run(self, duration, start_t=None):
return self.__call__(duration, start_t=start_t)

def __call__(self, duration, start_t=None):
"""The running function.
def reset_state(self):
self.i0 = 0
self.t0 = check_float(self._t0, 't0', allow_none=False, allow_int=True)
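
# Hedged note on time bookkeeping: `i0`/`t0` accumulate across calls, so successive
# runs continue from the previous end time unless the state is reset explicitly
# (runner and argument values are illustrative):
#
#   runner.run(100.)                    # simulates [t0, t0 + 100)
#   runner.run(100.)                    # continues from the previous end time
#   runner.run(100., reset_state=True)  # resets model states and restarts from t0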

def predict(
self,
duration: Union[float, int] = None,
inputs: Union[Array, Sequence[Array], Dict[str, Array]] = None,
inputs_are_batching: bool = False,
reset_state: bool = False,
shared_args: Dict = None,
progress_bar: bool = True,
eval_time: bool = False
) -> Output:
"""Running a duration with the given target model. See `.predict()` function
for more details.

This function uses JIT compilation to accelerate the model simulation.
Moreover, it can automatically monitor the node variables, states, inputs,
feedbacks and outputs.

Parameters
----------
duration : float, int, tuple, list
The running duration.
start_t : float, optional
duration: int, float
The simulation time length.
inputs: Array, dict of Array, sequence of Array
The input data. If ``inputs_are_batching=True``, ``inputs`` must be a
PyTree of data with two dimensions: `(num_sample, num_time, ...)`.
Otherwise, the ``inputs`` should be a PyTree of data with one dimension:
`(num_time, ...)`.
inputs_are_batching: bool
Whether the ``inputs`` are batching. If `True`, the batching axis is the
first dimension.
reset_state: bool
Whether to reset the model states.
shared_args: optional, dict
The shared arguments across different layers.
progress_bar: bool
Whether to report the progress of the simulation using a progress bar.
eval_time: bool
Whether to evaluate the running time.

Returns
-------
running_time : float
The total running time.
output: Array, dict, sequence
The model output.
"""
# time step
if start_t is None:
if self._start_t is None:
start_t = 0.
else:
start_t = float(self._start_t)
end_t = float(start_t + duration)
# times
times = np.arange(start_t, end_t, self.dt)

# shared arguments
if shared_args is None: shared_args = dict()
shared_args['fit'] = shared_args.get('fit', False)

# times and inputs
times, indices, xs, num_step, num_batch, duration, description = self._format_xs(
duration, inputs, inputs_are_batching)

# reset the states of the model and the runner
if reset_state:
self.target.reset_state(num_batch)
self.reset_state()
indices += self.i0
times += self.t0

# build monitor
for key in self.mon.item_contents.keys():
self.mon.item_contents[key] = [] # reshape the monitor items
# running
if self.progress_bar:
self._pbar = tqdm.auto.tqdm(total=times.size)
self._pbar.set_description(f"Running a duration of {round(float(duration), 3)} ({times.size} steps)",
refresh=True)
t0 = time.time()
_, hists = self._run_func(times)
running_time = time.time() - t0
if self.progress_bar:
self._pbar.close()
# post-running
if self.jit:
self.mon.ts = times + self.dt
for key in self.mon.item_names:
self.mon.item_contents[key] = bm.asarray(hists[key])
else:
self.mon.ts = times + self.dt
for key in self.mon.item_names:
self.mon.item_contents[key] = bm.asarray(self.mon.item_contents[key])
self._start_t = end_t
if self.numpy_mon_after_run:
self.mon.numpy()
return running_time
for key in self.mon.var_names:
self.mon[key] = [] # reshape the monitor items

# init progress bar
if self.progress_bar and progress_bar:
self._pbar = tqdm.auto.tqdm(total=num_step)
self._pbar.set_description(description, refresh=True)

class StructRunner(DSRunner):
"""The runner with the structural for-loop.
# running
if eval_time: t0 = time.time()
outputs, hists = self._predict(xs=(times, indices, xs), shared_args=shared_args)
if eval_time: running_time = time.time() - t0

.. deprecated:: 2.0.3
Prefer the use of :py:class:`brainpy.dyn.DSRunner` for dynamical system running.
This runner is deprecated since 2.0.3.
"""
# format
if inputs_are_batching:
outputs = tree_map(lambda x: bm.moveaxis(x, 0, 1), outputs, is_leaf=lambda x: isinstance(x, bm.JaxArray))
hists = tree_map(lambda x: bm.moveaxis(x, 0, 1), hists, is_leaf=lambda x: isinstance(x, bm.JaxArray))

def __init__(self, target, *args, **kwargs):
super(StructRunner, self).__init__(target, *args, **kwargs)
# close the progress bar
if self.progress_bar and progress_bar:
self._pbar.close()

# post-running for monitors
hists['ts'] = times + self.dt
if self.numpy_mon_after_run:
hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.JaxArray))
for key in hists.keys():
self.mon[key] = hists[key]
self.i0 += times.shape[0]
self.t0 += duration
return outputs if not eval_time else (running_time, outputs)
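
# Hedged sketch of the two input layouts accepted above (shapes are illustrative):
#
#   xs = bm.random.rand(200, 3)              # (num_time, num_feature)
#   out = runner.predict(inputs=xs)          # inputs_are_batching=False (default)
#
#   xs = bm.random.rand(8, 200, 3)           # (num_sample, num_time, num_feature)
#   out = runner.predict(inputs=xs, inputs_are_batching=True, reset_state=True)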

def _predict(
self,
xs: Sequence,
shared_args: Dict = None,
) -> Union[Output, Monitor]:
"""Predict the output according to the inputs.

class ReportRunner(DSRunner):
"""The runner provides convenient interface for debugging.
It is also able to report the running progress.
Parameters
----------
xs: sequence
Must be a tuple/list of data, including `(times, indices, inputs)`.
If `inputs` is not None, it should be a tensor with the shape of
:math:`(num_time, ...)`.
shared_args: optional, dict
The shared keyword arguments.

.. deprecated:: 2.0.3
Prefer the use of :py:class:`brainpy.dyn.DSRunner` for dynamical system running.
This runner is deprecated since 2.0.3.
Returns
-------
outputs, hists
A tuple of pair of (outputs, hists).
"""
_predict_func = self.f_predict(shared_args)
outputs, hists = _predict_func(xs)
return outputs, hists

Parameters
----------
target : DynamicalSystem
The target model to run.
monitors : None, list of str, tuple of str, Monitor
Variables to monitor.
inputs : list, tuple
The input settings.
"""
def run(self, *args, **kwargs) -> Output:
"""Predict a series of input data with the given target model.

def __init__(self, target, inputs=(), jit=False, dt=None, **kwargs):
super(ReportRunner, self).__init__(target=target, inputs=inputs, dt=dt, jit=False, **kwargs)
This function uses JIT compilation to accelerate the model simulation.
Moreover, it can automatically monitor the node variables, states, inputs,
feedbacks and outputs.

# Build the update function
if jit:
dyn_vars = TensorCollector()
dyn_vars.update(self.dyn_vars)
dyn_vars.update(self.target.vars().unique())
self._update_step = bm.jit(self.target.update, dyn_vars=dyn_vars)
Parameters
----------
duration: int, float
The simulation time length.
inputs: Array, dict of Array, sequence of Array
The input data. If ``inputs_are_batching=True``, ``inputs`` must be a
PyTree of data with two dimensions: `(num_sample, num_time, ...)`.
Otherwise, the ``inputs`` should be a PyTree of data with one dimension:
`(num_time, ...)`.
inputs_are_batching: bool
Whether the ``inputs`` are batching. If `True`, the batching axis is the
first dimension.
reset_state: bool
Whether to reset the model states.
shared_args: optional, dict
The shared arguments across different layers.
progress_bar: bool
Whether to report the progress of the simulation using a progress bar.
eval_time: bool
Whether to evaluate the running time.

Returns
-------
output: Array, dict, sequence
The model output.
"""
return self.predict(*args, **kwargs)

def __call__(self, *args, **kwargs) -> Output:
return self.predict(*args, **kwargs)

def _format_xs(self, duration, inputs, inputs_are_batching=True, move_axis=True):
if duration is None:
if inputs is None:
raise ValueError('"duration" and "inputs" can not both be None.')
xs, num_step, num_batch = self._check_xs(inputs,
move_axis=move_axis,
inputs_are_batching=inputs_are_batching)
indices = jax.device_put(jnp.arange(num_step))
times = jax.device_put(indices * self.dt)
description = f'Predict {num_step} steps: '
duration = num_step * self.dt
else:
self._update_step = self.target.update

def _run_one_step(self, _t):
self._input_step(_t, self.dt)
self._update_step(_t, self.dt)
if self.progress_bar:
self._pbar.update()
return self._monitor_step(_t, self.dt)

def build_run_function(self):
def f_run(all_t):
for i in range(all_t.shape[0]):
mon = self._run_one_step(all_t[i])
for k, v in mon.items():
self.mon.item_contents[k].append(v)
return None, {}

return f_run
times = jax.device_put(jnp.arange(0, duration, self.dt))
num_step = times.shape[0]
indices = jax.device_put(jnp.arange(num_step))
description = f'Running a duration of {round(float(duration), 3)} ({times.shape[0]} steps)'
if inputs is None:
xs, num_batch = None, None
else:
xs, num_step_, num_batch = self._check_xs(inputs,
move_axis=move_axis,
inputs_are_batching=inputs_are_batching)
if num_step_ != num_step:
raise ValueError('The step numbers of "time" and "inputs" '
f'do not match: {num_step_} != {num_step}.')
return times, indices, xs, num_step, num_batch, duration, description
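
# Hedged summary of the two calling modes handled by _format_xs:
#
#   times, indices, xs, n_step, n_batch, duration, desc = runner._format_xs(
#       duration=100., inputs=None)    # time-driven: steps come from duration / dt
#   times, indices, xs, n_step, n_batch, duration, desc = runner._format_xs(
#       duration=None, inputs=data)    # data-driven: steps come from the inputs
#
# When both are given, the step number derived from `duration` must match the
# time dimension of `inputs`.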

def _check_xs(self, xs, move_axis=True, inputs_are_batching=True):
leaves, tree = tree_flatten(xs, is_leaf=lambda x: isinstance(x, bm.JaxArray))

# get information of time step and batch size
if inputs_are_batching:
num_times, num_batch_sizes = [], []
for val in leaves:
num_batch_sizes.append(val.shape[0])
num_times.append(val.shape[1])
else:
num_times = [val.shape[0] for val in leaves]
if len(set(num_times)) != 1:
raise ValueError(f'The number of time steps differs across tensors in '
f'the provided "xs". We got {set(num_times)}.')
num_step = num_times[0]
if inputs_are_batching:
if len(set(num_batch_sizes)) != 1:
raise ValueError(f'The batch size differs across tensors in '
f'the provided "xs". We got {set(num_batch_sizes)}.')
num_batch = num_batch_sizes[0]
else:
num_batch = None

# change shape to (num_time, num_sample, num_feature)
if move_axis and inputs_are_batching:
xs = tree_map(lambda x: bm.moveaxis(x, 0, 1), xs,
is_leaf=lambda x: isinstance(x, bm.JaxArray))
return xs, num_step, num_batch
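
# Sketch of the axis convention enforced by _check_xs (illustrative shapes):
# batched inputs arrive as (num_sample, num_time, ...) and are moved to
# (num_time, num_sample, ...) so that the loop scans over time first.
#
#   xs = bm.random.rand(8, 200, 3)
#   xs_, n_step, n_batch = runner._check_xs(xs)   # xs_.shape == (200, 8, 3), 200, 8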

def f_predict(self, shared_args: Dict = None):
if shared_args is None: shared_args = dict()

shared_kwargs_str = serialize_kwargs(shared_args)
if shared_kwargs_str not in self._f_predict_compiled:

monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args)

def _step_func(inputs):
t, i, x = inputs
self.target.clear_input()
# input step
shared = DotDict(t=t, i=i, dt=self.dt)
self._input_step(shared)
# dynamics update step
shared.update(shared_args)
args = (shared,) if x is None else (shared, x)
out = self.target(*args)
# monitor step
mon = monitor_func(shared)
# finally
if self.progress_bar:
id_tap(lambda *arg: self._pbar.update(), ())
return out, mon

if self.jit['predict']:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True)
run_func = lambda all_inputs: f(all_inputs)[1]

else:
def run_func(xs):
# total data
times, indices, xs = xs

outputs = []
monitors = {key: [] for key in (set(self.mon.var_names) | set(self.fun_monitors.keys()))}
for i in range(times.shape[0]):
# data at time i
x = tree_map(lambda x: x[i], xs, is_leaf=lambda x: isinstance(x, bm.JaxArray))

# step at the i
output, mon = _step_func((times[i], indices[i], x))

# append output and monitor
outputs.append(output)
for key, value in mon.items():
monitors[key].append(value)

# final work
if outputs[0] is None:
outputs = None
else:
outputs = bm.asarray(outputs)
for key, value in monitors.items():
monitors[key] = bm.asarray(value)
return outputs, monitors
self._f_predict_compiled[shared_kwargs_str] = run_func
return self._f_predict_compiled[shared_kwargs_str]
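
# Hedged illustration of the cache above: compiled run functions are keyed by the
# serialized shared arguments, so repeated predictions with the same settings reuse
# one compiled loop while different settings are compiled separately.
#
#   f1 = runner.f_predict({'fit': False})
#   f2 = runner.f_predict({'fit': False})   # same key -> f1 is f2
#   f3 = runner.f_predict({'fit': True})    # new key  -> compiled independently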

def __del__(self):
if hasattr(self, '_f_predict_compiled'):
for key in tuple(self._f_predict_compiled.keys()):
del self._f_predict_compiled[key]
super(DSRunner, self).__del__()

+ 4
- 0
brainpy/dyn/synapses/__init__.py

@@ -3,4 +3,8 @@
from .abstract_models import *
from .biological_models import *
from .learning_rules import *
from .gap_junction import *
from .delay_couplings import *

# compatible interface
from . import compat

+ 404
- 810
brainpy/dyn/synapses/abstract_models.py

@@ -1,39 +1,43 @@
# -*- coding: utf-8 -*-

from typing import Union, Dict, Callable
from typing import Union, Dict, Callable, Optional

from jax import vmap
from jax.lax import stop_gradient

import brainpy.math as bm
from brainpy.connect import TwoEndConnector, All2All, One2One
from brainpy.dyn.base import NeuGroup, TwoEndConn
from brainpy.initialize import Initializer, init_param
from brainpy.dyn.base import NeuGroup, SynOut, SynSTP, TwoEndConn
from brainpy.initialize import Initializer, variable
from brainpy.integrators import odeint, JointEq
from brainpy.types import Tensor
from brainpy.modes import Mode, BatchingMode, normal
from brainpy.types import Array
from ..synouts import CUBA, MgBlock

__all__ = [
'DeltaSynapse',
'ExpCUBA',
'ExpCOBA',
'DualExpCUBA',
'DualExpCOBA',
'AlphaCUBA',
'AlphaCOBA',
'Delta',
'Exponential',
'DualExponential',
'Alpha',
'NMDA',
]


class DeltaSynapse(TwoEndConn):
"""Voltage Jump Synapse Model, or alias of Delta Synapse Model.
class Delta(TwoEndConn):
r"""Voltage Jump Synapse Model, or alias of Delta Synapse Model.

**Model Descriptions**

.. math::

I_{syn} (t) = \sum_{j\in C} w \delta(t-t_j-D)
I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \mathrm{STP} * \delta(t-t_j-D)

where :math:`w` denotes the chemical synaptic strength, :math:`t_j` the spiking
moment of the presynaptic neuron :math:`j`, :math:`C` the set of neurons connected
to the post-synaptic neuron, and :math:`D` the transmission delay of chemical
synapses. For simplicity, the rise and decay phases of post-synaptic currents are
where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength,
:math:`t_j` the spiking moment of the presynaptic neuron :math:`j`,
:math:`C` the set of neurons connected to the post-synaptic neuron,
:math:`D` the transmission delay of chemical synapses,
and :math:`\mathrm{STP}` the short-term plasticity effect.
For simplicity, the rise and decay phases of post-synaptic currents are
omitted in this model.

**Model Examples**
@@ -42,11 +46,12 @@ class DeltaSynapse(TwoEndConn):
:include-source: True

>>> import brainpy as bp
>>> from brainpy.dyn import synapses, neurons
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.LIF(1)
>>> neu2 = bp.dyn.LIF(1)
>>> syn1 = bp.dyn.DeltaSynapse(neu1, neu2, bp.connect.All2All(), weights=5.)
>>> neu1 = neurons.LIF(1)
>>> neu2 = neurons.LIF(1)
>>> syn1 = synapses.Delta(neu1, neu2, bp.connect.All2All(), g_max=5.)
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 25.), ('post.input', 10.)], monitors=['pre.V', 'post.V', 'pre.spike'])
@@ -67,17 +72,14 @@ class DeltaSynapse(TwoEndConn):
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `sparse`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
weights: float, ndarray, JaxArray, Initializer, Callable
g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength. Default is 1.
post_key: str
The key of the post variable. It should be a string. The key should
be the attribute of the post-synaptic neuron group.
post_has_ref: bool
post_ref_key: str
The key of the refractory variable in the post-synaptic neuron group. If provided,
the synaptic output is masked during the refractory period.
"""

@@ -85,106 +87,86 @@ class DeltaSynapse(TwoEndConn):
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'sparse',
weights: Union[float, Tensor, Initializer, Callable] = 1.,
delay_step: Union[float, Tensor, Initializer, Callable] = None,
post_key: str = 'V',
post_has_ref: bool = False,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
output: SynOut = CUBA(target_var='V'),
stp: Optional[SynSTP] = None,
comp_method: str = 'sparse',
g_max: Union[float, Array, Initializer, Callable] = 1.,
delay_step: Union[float, Array, Initializer, Callable] = None,
post_ref_key: str = None,

# other parameters
name: str = None,
mode: Mode = normal,
stop_spike_gradient: bool = False,
):
super(DeltaSynapse, self).__init__(pre=pre, post=post, conn=conn, name=name)
self.check_pre_attrs('spike')
super(Delta, self).__init__(name=name,
pre=pre,
post=post,
conn=conn,
output=output,
stp=stp,
mode=mode)

# parameters
self.post_key = post_key
self.check_post_attrs(post_key)
self.post_has_ref = post_has_ref
if post_has_ref:
self.check_post_attrs('refractory')
self.stop_spike_gradient = stop_spike_gradient
self.post_ref_key = post_ref_key
if post_ref_key:
self.check_post_attrs(post_ref_key)
self.comp_method = comp_method

# connections and weights
self.conn_type = conn_type
if conn_type not in ['sparse', 'dense']:
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
if isinstance(self.conn, One2One):
self.weights = init_param(weights, (self.pre.num,), allow_none=False)
self.weight_type = 'heter' if bm.size(self.weights) != 1 else 'homo'
elif isinstance(self.conn, All2All):
self.weights = init_param(weights, (self.pre.num, self.post.num), allow_none=False)
if bm.size(self.weights) != 1:
self.weight_type = 'heter'
bm.fill_diagonal(self.weights, 0.)
else:
self.weight_type = 'homo'
else:
if conn_type == 'sparse':
self.pre2post = self.conn.require('pre2post')
self.weights = init_param(weights, self.pre2post[1].shape, allow_none=False)
self.weight_type = 'heter' if bm.size(self.weights) != 1 else 'homo'
elif conn_type == 'dense':
self.weights = init_param(weights, (self.pre.num, self.post.num), allow_none=False)
self.weight_type = 'heter' if bm.size(self.weights) != 1 else 'homo'
if self.weight_type == 'homo':
self.conn_mat = self.conn.require('conn_mat')
else:
raise ValueError(f'Unknown connection type: {conn_type}')
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method=comp_method, sparse_data='csr')

# variables
self.delay_step = self.register_delay(f"{self.pre.name}.spike",
delay_step=delay_step,
delay_target=self.pre.spike)
# register delay
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

def reset(self):
if self.delay_step is not None:
self.reset_delay(f"{self.pre.name}.spike", self.pre.spike)
def reset_state(self, batch_size=None):
self.output.reset_state(batch_size)
if self.stp is not None: self.stp.reset_state(batch_size)

def update(self, t, dt):
# delays
if self.delay_step is None:
pre_spike = self.pre.spike
else:
def update(self, tdi, pre_spike=None):
# pre-synaptic spikes
if pre_spike is None:
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", delay_step=self.delay_step)
self.update_delay(f"{self.pre.name}.spike", delay_data=self.pre.spike)
if self.stop_spike_gradient:
pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike
pre_spike = stop_gradient(pre_spike)

# post values
assert self.weight_type in ['homo', 'heter']
assert self.conn_type in ['sparse', 'dense']
# update sub-components
self.output.update(tdi)
if self.stp is not None: self.stp.update(tdi, pre_spike)

# synaptic values onto the post
if isinstance(self.conn, All2All):
if self.weight_type == 'homo':
post_vs = bm.sum(pre_spike)
if not self.conn.include_self:
post_vs = post_vs - pre_spike
post_vs *= self.weights
else:
post_vs = bm.expand_dims(pre_spike, 1) * self.weights
post_vs = post_vs.sum(axis=0)
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = pre_spike * self.weights
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
else:
if self.conn_type == 'sparse':
post_vs = bm.pre2post_event_sum(pre_spike,
self.pre2post,
self.post.num,
self.weights)
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(pre_spike)
# if not isinstance(self.stp, _NullSynSTP):
# raise NotImplementedError()
# stp_value = self.stp(1.)
# f2 = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
# if self.trainable: f2 = vmap(f2)
# post_vs *= f2(stp_value)
else:
if self.weight_type == 'homo':
post_vs = self.weights * (pre_spike @ self.conn_mat)
else:
post_vs = pre_spike @ self.weights
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
if self.post_ref_key:
post_vs = post_vs * (1. - getattr(self.post, self.post_ref_key))

# update outputs
target = getattr(self.post, self.post_key)
if self.post_has_ref:
target += post_vs * bm.logical_not(self.post.refractory)
else:
target += post_vs
return self.output(post_vs)
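
# Hedged sketch of the shared-argument update convention used above
# (the synapse instance and spike array are hypothetical):
#
#   shared = dict(t=0.1, dt=0.1, i=1)
#   syn.update(shared)                     # fetches the delayed pre.spike internally
#   syn.update(shared, pre_spike=spikes)   # or feed the pre-synaptic spikes explicitly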


class ExpCUBA(TwoEndConn):
r"""Current-based exponential decay synapse model.
class Exponential(TwoEndConn):
r"""Exponential decay synapse model.

**Model Descriptions**

@@ -206,17 +188,13 @@ class ExpCUBA(TwoEndConn):
.. math::

\begin{aligned}
& g_{\mathrm{syn}}(t) = g_{max} g \\
& g_{\mathrm{syn}}(t) = g_{max} g * \mathrm{STP} \\
& \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}).
\end{aligned}

For the current output onto the post-synaptic neuron, its expression is given by

.. math::

I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t)


where :math:`\mathrm{STP}` is used to model the short-term plasticity effect.
**Model Examples**

- `(Brunel & Hakim, 1999) Fast Global Oscillation <https://brainpy-examples.readthedocs.io/en/latest/oscillation_synchronization/Brunel_Hakim_1999_fast_oscillation.html>`_
@@ -228,11 +206,13 @@ class ExpCUBA(TwoEndConn):
:include-source: True

>>> import brainpy as bp
>>> from brainpy.dyn import neurons, synapses, synouts
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.LIF(1)
>>> neu2 = bp.dyn.LIF(1)
>>> syn1 = bp.dyn.ExpCUBA(neu1, neu2, bp.conn.All2All(), g_max=5.)
>>> neu1 = neurons.LIF(1)
>>> neu2 = neurons.LIF(1)
>>> syn1 = synapses.Exponential(neu1, neu2, bp.conn.All2All(),
>>> g_max=5., output=synouts.CUBA())
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g'])
@@ -257,7 +237,7 @@ class ExpCUBA(TwoEndConn):
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `sparse`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
@@ -277,342 +257,223 @@ class ExpCUBA(TwoEndConn):
.. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
"The Synapse." Principles of Computational Modelling in Neuroscience.
Cambridge: Cambridge UP, 2011. 172-95. Print.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'sparse',
g_max: Union[float, Tensor, Initializer, Callable] = 1.,
delay_step: Union[int, Tensor, Initializer, Callable] = None,
tau: Union[float, Tensor] = 8.0,
name: str = None,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
output: SynOut = CUBA(),
stp: Optional[SynSTP] = None,
comp_method: str = 'sparse',
g_max: Union[float, Array, Initializer, Callable] = 1.,
delay_step: Union[int, Array, Initializer, Callable] = None,
tau: Union[float, Array] = 8.0,
method: str = 'exp_auto',
):
super(ExpCUBA, self).__init__(pre=pre, post=post, conn=conn, name=name)
self.check_pre_attrs('spike')
self.check_post_attrs('input', 'V')

# other parameters
name: str = None,
mode: Mode = normal,
stop_spike_gradient: bool = False,
):
super(Exponential, self).__init__(pre=pre,
post=post,
conn=conn,
output=output,
stp=stp,
name=name,
mode=mode)
# parameters
self.stop_spike_gradient = stop_spike_gradient
self.comp_method = comp_method
self.tau = tau
if bm.size(self.tau) != 1:
raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. '
f'But we got {self.tau}')
raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')

# connections and weights
self.conn_type = conn_type
if conn_type not in ['sparse', 'dense']:
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
if isinstance(self.conn, One2One):
self.g_max = init_param(g_max, (self.pre.num,), allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
elif isinstance(self.conn, All2All):
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
if bm.size(self.g_max) != 1:
self.weight_type = 'heter'
bm.fill_diagonal(self.g_max, 0.)
else:
self.weight_type = 'homo'
else:
if conn_type == 'sparse':
self.pre2post = self.conn.require('pre2post')
self.g_max = init_param(g_max, self.pre2post[1].shape, allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
elif conn_type == 'dense':
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
if self.weight_type == 'homo':
self.conn_mat = self.conn.require('conn_mat')
else:
raise ValueError(f'Unknown connection type: {conn_type}')
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='csr')

# variables
self.g = bm.Variable(bm.zeros(self.post.num))
self.g = variable(bm.zeros, mode, self.post.num)
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

# function
self.integral = odeint(lambda g, t: -g / self.tau, method=method)

def reset(self):
self.g.value = bm.zeros(self.post.num)
if self.delay_step is not None:
self.reset_delay(f"{self.pre.name}.spike", self.pre.spike)
def reset_state(self, batch_size=None):
self.g.value = variable(bm.zeros, batch_size, self.post.num)
self.output.reset_state(batch_size)
if self.stp is not None: self.stp.reset_state(batch_size)

def update(self, tdi, pre_spike=None):
t, dt = tdi['t'], tdi['dt']

def update(self, t, dt):
# delays
if self.delay_step is None:
pre_spike = self.pre.spike
else:
if pre_spike is None:
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
self.update_delay(f"{self.pre.name}.spike", self.pre.spike)
if self.stop_spike_gradient:
pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike
pre_spike = stop_gradient(pre_spike)

# update sub-components
self.output.update(tdi)
if self.stp is not None: self.stp.update(tdi, pre_spike)

# post values
assert self.weight_type in ['homo', 'heter']
assert self.conn_type in ['sparse', 'dense']
if isinstance(self.conn, All2All):
if self.weight_type == 'homo':
post_vs = bm.sum(pre_spike)
if not self.conn.include_self:
post_vs = post_vs - pre_spike
post_vs = self.g_max * post_vs
else:
post_vs = pre_spike @ self.g_max
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = pre_spike * self.g_max
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
else:
if self.conn_type == 'sparse':
post_vs = bm.pre2post_event_sum(pre_spike,
self.pre2post,
self.post.num,
self.g_max)
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(pre_spike)
# if not isinstance(self.stp, _NullSynSTP):
# raise NotImplementedError()
else:
if self.weight_type == 'homo':
post_vs = self.g_max * (pre_spike @ self.conn_mat)
else:
post_vs = pre_spike @ self.g_max

syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
# updates
self.g.value = self.integral(self.g.value, t, dt=dt) + post_vs
self.post.input += self.output(self.g)

def output(self, g_post):
return g_post


class ExpCOBA(ExpCUBA):
"""Conductance-based exponential decay synapse model.

**Model Descriptions**

The conductance-based exponential decay synapse model is similar with the
`current-based exponential decay synapse model <./brainmodels.synapses.ExpCUBA.rst>`_,
except the expression which output onto the post-synaptic neurons:

.. math::

I_{syn}(t) = g_{\mathrm{syn}}(t) (V(t)-E)

where :math:`V(t)` is the membrane potential of the post-synaptic neuron,
:math:`E` is the reversal potential.

**Model Examples**

- `(Brette, et, al., 2007) COBA <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Brette_2007_COBA.html>`_
- `(Brette, et, al., 2007) COBAHH <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Brette_2007_COBAHH.html>`_


.. plot::
:include-source: True

>>> import brainpy as bp
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.HH(1)
>>> neu2 = bp.dyn.HH(1)
>>> syn1 = bp.dyn.ExpCOBA(neu1, neu2, bp.connect.All2All(), E=0.)
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>>
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.legend()
>>> plt.show()

Parameters
----------
pre: NeuGroup
The pre-synaptic neuron group.
post: NeuGroup
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `sparse`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
E: float, JaxArray, ndarray
The reversal potential for the synaptic current. [mV]
tau: float, JaxArray, ndarray
The time constant of decay. [ms]
g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.

References
----------
self.g.value = self.integral(self.g.value, t, dt) + post_vs

.. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
"The Synapse." Principles of Computational Modelling in Neuroscience.
Cambridge: Cambridge UP, 2011. 172-95. Print.
"""
# output
return self.output(self.g)
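
# Hedged migration sketch: the removed ExpCUBA/ExpCOBA pair maps onto the new
# Exponential class plus an output component (assuming brainpy.dyn.synouts also
# provides a conductance-based COBA(E=...) output in addition to CUBA):
#
#   syn_cuba = synapses.Exponential(pre, post, bp.conn.All2All(), output=synouts.CUBA())
#   syn_coba = synapses.Exponential(pre, post, bp.conn.All2All(), output=synouts.COBA(E=0.))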


class DualExponential(TwoEndConn):
r"""Dual exponential synapse model.

**Model Descriptions**

The dual exponential synapse model [1]_, also named as *difference of two exponentials* model,
is given by:

.. math::

g_{\mathrm{syn}}(t)=g_{\mathrm{max}} \frac{\tau_{1} \tau_{2}}{
\tau_{1}-\tau_{2}}\left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right)
-\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right)

where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2`
is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic
spike, :math:`g_{\mathrm{max}}` is the maximal conductance.

However, in practice, this formula is hard to implement. The equivalent solution is
two coupled linear differential equations [2]_:

.. math::

\begin{aligned}
&g_{\mathrm{syn}}(t)=g_{\mathrm{max}} g * \mathrm{STP} \\
&\frac{d g}{d t}=-\frac{g}{\tau_{\mathrm{decay}}}+h \\
&\frac{d h}{d t}=-\frac{h}{\tau_{\text {rise }}}+ \delta\left(t_{0}-t\right),
\end{aligned}

where :math:`\mathrm{STP}` is used to model the short-term plasticity effect of synapses.

**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>> from brainpy.dyn import neurons, synapses, synouts
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = neurons.LIF(1)
>>> neu2 = neurons.LIF(1)
>>> syn1 = synapses.DualExponential(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA())
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>>
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h')
>>> plt.legend()
>>> plt.show()

Parameters
----------
pre: NeuGroup
The pre-synaptic neuron group.
post: NeuGroup
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `sparse`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
tau_decay: float, JaxArray, JaxArray, ndarray
The time constant of the synaptic decay phase. [ms]
tau_rise: float, JaxArray, JaxArray, ndarray
The time constant of the synaptic rise phase. [ms]
g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.

References
----------

.. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
"The Synapse." Principles of Computational Modelling in Neuroscience.
Cambridge: Cambridge UP, 2011. 172-95. Print.
.. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational
Modeling Methods for Neuroscientists.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
# connection
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'sparse',
# connection strength
g_max: Union[float, Tensor, Initializer, Callable] = 1.,
# synapse parameter
tau: Union[float, Tensor] = 8.0,
E: Union[float, Tensor] = 0.,
# synapse delay
delay_step: Union[int, Tensor, Initializer, Callable] = None,
# others
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
stp: Optional[SynSTP] = None,
output: SynOut = CUBA(),
comp_method: str = 'dense',
g_max: Union[float, Array, Initializer, Callable] = 1.,
tau_decay: Union[float, Array] = 10.0,
tau_rise: Union[float, Array] = 1.,
delay_step: Union[int, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
):
super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn,
conn_type=conn_type,
g_max=g_max, delay_step=delay_step,
tau=tau, method=method, name=name)

# parameter
self.E = E
if bm.size(self.E) != 1:
raise ValueError(f'"E" must be a scalar or a tensor with size of 1. '
f'But we got {self.E}')

def output(self, g_post):
return g_post * (self.E - self.post.V)


class DualExpCUBA(TwoEndConn):
r"""Current-based dual exponential synapse model.

**Model Descriptions**

The dual exponential synapse model [1]_, also named as *difference of two exponentials* model,
is given by:

.. math::

g_{\mathrm{syn}}(t)=g_{\mathrm{max}} \frac{\tau_{1} \tau_{2}}{
\tau_{1}-\tau_{2}}\left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right)
-\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right)

where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2`
is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic
spike, :math:`g_{\mathrm{max}}` is the maximal conductance.

However, in practice, this formula is hard to implement. The equivalent solution is
two coupled linear differential equations [2]_:

.. math::

\begin{aligned}
&g_{\mathrm{syn}}(t)=g_{\mathrm{max}} g \\
&\frac{d g}{d t}=-\frac{g}{\tau_{\mathrm{decay}}}+h \\
&\frac{d h}{d t}=-\frac{h}{\tau_{\text {rise }}}+ \delta\left(t_{0}-t\right),
\end{aligned}

The current onto the post-synaptic neuron is given by

.. math::

I_{syn}(t) = g_{\mathrm{syn}}(t).


**Model Examples**


.. plot::
:include-source: True

>>> import brainpy as bp
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.LIF(1)
>>> neu2 = bp.dyn.LIF(1)
>>> syn1 = bp.dyn.DualExpCUBA(neu1, neu2, bp.connect.All2All())
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>>
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h')
>>> plt.legend()
>>> plt.show()

Parameters
----------
pre: NeuGroup
The pre-synaptic neuron group.
post: NeuGroup
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `sparse`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
tau_decay: float, JaxArray, JaxArray, ndarray
The time constant of the synaptic decay phase. [ms]
tau_rise: float, JaxArray, JaxArray, ndarray
The time constant of the synaptic rise phase. [ms]
g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.

References
----------

.. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
"The Synapse." Principles of Computational Modelling in Neuroscience.
Cambridge: Cambridge UP, 2011. 172-95. Print.
.. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational
Modeling Methods for Neuroscientists.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'dense',
g_max: Union[float, Tensor, Initializer, Callable] = 1.,
tau_decay: Union[float, Tensor] = 10.0,
tau_rise: Union[float, Tensor] = 1.,
delay_step: Union[int, Tensor, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
# other parameters
name: str = None,
mode: Mode = normal,
stop_spike_gradient: bool = False,
):
super(DualExpCUBA, self).__init__(pre=pre, post=post, conn=conn, name=name)
self.check_pre_attrs('spike')
self.check_post_attrs('input')

super(DualExponential, self).__init__(pre=pre,
post=post,
conn=conn,
output=output,
stp=stp,
name=name,
mode=mode)
# parameters
# self.check_pre_attrs('spike')
self.check_post_attrs('input')
self.stop_spike_gradient = stop_spike_gradient
self.comp_method = comp_method
self.tau_rise = tau_rise
self.tau_decay = tau_decay
if bm.size(self.tau_rise) != 1:
@@ -623,47 +484,21 @@ class DualExpCUBA(TwoEndConn):
f'But we got {self.tau_decay}')

# connections
self.conn_type = conn_type
if conn_type not in ['sparse', 'dense']:
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
if isinstance(self.conn, One2One):
self.g_max = init_param(g_max, (self.pre.num,), allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
elif isinstance(self.conn, All2All):
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
if bm.size(self.g_max) != 1:
self.weight_type = 'heter'
bm.fill_diagonal(self.g_max, 0.)
else:
self.weight_type = 'homo'
else:
if conn_type == 'sparse':
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
elif conn_type == 'dense':
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
if self.weight_type == 'homo':
self.conn_mat = self.conn.require('conn_mat')
else:
raise ValueError(f'Unknown connection type: {conn_type}')
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.h = bm.Variable(bm.zeros(self.pre.num))
self.g = bm.Variable(bm.zeros(self.pre.num))
self.h = variable(bm.zeros, mode, self.pre.num)
self.g = variable(bm.zeros, mode, self.pre.num)
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

# integral
self.integral = odeint(method=method, f=JointEq([self.dg, self.dh]))

def reset(self):
self.h.value = bm.zeros(self.pre.num)
self.g.value = bm.zeros(self.pre.num)
if self.delay_step is not None:
self.reset_delay(f"{self.pre.name}.spike", self.pre.spike)
def reset_state(self, batch_size=None):
self.h.value = variable(bm.zeros, batch_size, self.pre.num)
self.g.value = variable(bm.zeros, batch_size, self.pre.num)
self.output.reset_state(batch_size)
if self.stp is not None: self.stp.reset_state(batch_size)

def dh(self, h, t):
return -h / self.tau_rise
@@ -671,158 +506,45 @@ class DualExpCUBA(TwoEndConn):
def dg(self, g, t, h):
return -g / self.tau_decay + h

def update(self, t, dt):
# delays
if self.delay_step is None:
pre_spike = self.pre.spike
else:
def update(self, tdi, pre_spike=None):
t, dt = tdi['t'], tdi['dt']
# pre-synaptic spikes
if pre_spike is None:
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
self.update_delay(f"{self.pre.name}.spike", self.pre.spike)
if self.stop_spike_gradient:
pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike
pre_spike = stop_gradient(pre_spike)

# update sub-components
self.output.update(tdi)
if self.stp is not None: self.stp.update(tdi, pre_spike)

# update synaptic variables
self.g.value, self.h.value = self.integral(self.g, self.h, t, dt)
self.h += pre_spike

# post-synaptic values
assert self.weight_type in ['homo', 'heter']
assert self.conn_type in ['sparse', 'dense']
# post values
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
if self.weight_type == 'homo':
post_vs = bm.sum(self.g)
if not self.conn.include_self:
post_vs = post_vs - self.g
post_vs = self.g_max * post_vs
else:
post_vs = self.g @ self.g_max
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.g_max * self.g
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
else:
if self.conn_type == 'sparse':
post_vs = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids)
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
if self.weight_type == 'homo':
post_vs = (self.g_max * self.g) @ self.conn_mat
else:
post_vs = self.g @ self.g_max
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
self.post.input += self.output(post_vs)

def output(self, g_post):
return g_post


class DualExpCOBA(DualExpCUBA):
"""Conductance-based dual exponential synapse model.

**Model Descriptions**

The conductance-based dual exponential synapse model is similar with the
`current-based dual exponential synapse model <./brainmodels.synapses.DualExpCUBA.rst>`_,
except the expression which output onto the post-synaptic neurons:

.. math::

I_{syn}(t) = g_{\mathrm{syn}}(t) (V(t)-E)

where :math:`V(t)` is the membrane potential of the post-synaptic neuron,
:math:`E` is the reversal potential.

**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.HH(1)
>>> neu2 = bp.dyn.HH(1)
>>> syn1 = bp.dyn.DualExpCOBA(neu1, neu2, bp.connect.All2All(), E=0.)
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>>
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h')
>>> plt.legend()
>>> plt.show()

Parameters
----------
pre: NeuGroup
The pre-synaptic neuron group.
post: NeuGroup
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `sparse`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
E: float, JaxArray, ndarray
The reversal potential for the synaptic current. [mV]
tau_decay: float, JaxArray, ndarray
The time constant of the synaptic decay phase. [ms]
tau_rise: float, JaxArray, ndarray
The time constant of the synaptic rise phase. [ms]
g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.

References
----------

.. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
"The Synapse." Principles of Computational Modelling in Neuroscience.
Cambridge: Cambridge UP, 2011. 172-95. Print.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'dense',
g_max: Union[float, Tensor, Initializer, Callable] = 1.,
delay_step: Union[int, Tensor, Initializer, Callable] = None,
tau_decay: Union[float, Tensor] = 10.0,
tau_rise: Union[float, Tensor] = 1.,
E: Union[float, Tensor] = 0.,
method: str = 'exp_auto',
name: str = None
):
super(DualExpCOBA, self).__init__(pre, post, conn, conn_type=conn_type,
delay_step=delay_step, g_max=g_max,
tau_decay=tau_decay, tau_rise=tau_rise,
method=method, name=name)
self.check_post_attrs('V')

# parameters
self.E = E
if bm.size(self.E) != 1:
raise ValueError(f'"E" must be a scalar or a tensor with size of 1. '
f'But we got {self.E}')

def output(self, g_post):
return g_post * (self.E - self.post.V)
return self.output(post_vs)


class AlphaCUBA(DualExpCUBA):
r"""Current-based alpha synapse model.
class Alpha(DualExponential):
r"""Alpha synapse model.

**Model Descriptions**

@@ -843,24 +565,18 @@ class AlphaCUBA(DualExpCUBA):
&\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right)
\end{aligned}

The current onto the post-synaptic neuron is given by

.. math::

I_{syn}(t) = g_{\mathrm{syn}}(t).


**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>> from brainpy.dyn import neurons, synapses, synouts
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.LIF(1)
>>> neu2 = bp.dyn.LIF(1)
>>> syn1 = bp.dyn.AlphaCUBA(neu1, neu2, bp.connect.All2All())
>>> neu1 = neurons.LIF(1)
>>> neu2 = neurons.LIF(1)
>>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA())
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h'])
@@ -885,7 +601,7 @@ class AlphaCUBA(DualExpCUBA):
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `sparse`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
@@ -911,126 +627,38 @@ class AlphaCUBA(DualExpCUBA):
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'dense',
g_max: Union[float, Tensor, Initializer, Callable] = 1.,
delay_step: Union[int, Tensor, Initializer, Callable] = None,
tau_decay: Union[float, Tensor] = 10.0,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
output: SynOut = CUBA(),
stp: Optional[SynSTP] = None,
comp_method: str = 'dense',
g_max: Union[float, Array, Initializer, Callable] = 1.,
delay_step: Union[int, Array, Initializer, Callable] = None,
tau_decay: Union[float, Array] = 10.0,
method: str = 'exp_auto',
name: str = None
):
super(AlphaCUBA, self).__init__(pre=pre, post=post, conn=conn,
conn_type=conn_type,
delay_step=delay_step,
g_max=g_max,
tau_decay=tau_decay,
tau_rise=tau_decay,
method=method,
name=name)


class AlphaCOBA(DualExpCOBA):
"""Conductance-based alpha synapse model.

**Model Descriptions**

The conductance-based alpha synapse model is similar to the
`current-based alpha synapse model <./brainmodels.synapses.AlphaCUBA.rst>`_,
except for the expression of the current output onto the post-synaptic neurons:

.. math::

I_{syn}(t) = g_{\mathrm{syn}}(t) (V(t)-E)

where :math:`V(t)` is the membrane potential of the post-synaptic neuron,
:math:`E` is the reversal potential.


**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.HH(1)
>>> neu2 = bp.dyn.HH(1)
>>> syn1 = bp.dyn.AlphaCOBA(neu1, neu2, bp.connect.All2All(), E=0.)
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h')
>>> plt.legend()
>>> plt.show()

Parameters
----------
pre: NeuGroup
The pre-synaptic neuron group.
post: NeuGroup
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `dense`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
E: float, JaxArray, ndarray
The reversal potential for the synaptic current. [mV]
tau_decay: float, JaxArray, ndarray
The time constant of the synaptic decay phase. [ms]
g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.

References
----------

.. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
"The Synapse." Principles of Computational Modelling in Neuroscience.
Cambridge: Cambridge UP, 2011. 172-95. Print.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'dense',
g_max: Union[float, Tensor, Callable, Initializer] = 1.,
delay_step: Union[int, Tensor, Initializer, Callable] = None,
tau_decay: Union[float, Tensor] = 10.0,
E: Union[float, Tensor] = 0.,
method: str = 'exp_auto',
name: str = None
# other parameters
name: str = None,
mode: Mode = normal,
stop_spike_gradient: bool = False,
):
super(AlphaCOBA, self).__init__(pre=pre, post=post, conn=conn,
conn_type=conn_type,
delay_step=delay_step,
g_max=g_max, E=E,
tau_decay=tau_decay,
tau_rise=tau_decay,
method=method,
name=name)
super(Alpha, self).__init__(pre=pre,
post=post,
conn=conn,
comp_method=comp_method,
delay_step=delay_step,
g_max=g_max,
tau_decay=tau_decay,
tau_rise=tau_decay,
method=method,
output=output,
stp=stp,
name=name,
mode=mode,
stop_spike_gradient=stop_spike_gradient)


class NMDA(TwoEndConn):
r"""Conductance-based NMDA synapse model.
r"""NMDA synapse model.

**Model Descriptions**

@@ -1062,7 +690,7 @@ class NMDA(TwoEndConn):

.. math::

I_{syn} = g_{NMDA}(t) (V(t)-E) \cdot g_{\infty}
I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}

where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the
reversal potential.
@@ -1071,7 +699,7 @@ class NMDA(TwoEndConn):

.. math::

& g_{NMDA} (t) = g_{max} g \\
& g_\mathrm{NMDA} (t) = g_{max} g \\
& \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\
& \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k})
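
A minimal illustrative sketch (not part of the diff) of the magnesium-block factor
and the resulting current described above, using the class defaults ``alpha=0.062``,
``beta=3.57`` and ``cc_Mg=1.2``; the helper names ``mg_block`` and ``nmda_current``
are hypothetical, and in the new API this factor is supplied by
``brainpy.dyn.synouts.MgBlock``:

import brainpy.math as bm

def mg_block(V, alpha=0.062, beta=3.57, cc_Mg=1.2):
    # fraction of NMDA channels not blocked by Mg2+ at membrane potential V [mV]
    return 1. / (1. + cc_Mg / beta * bm.exp(-alpha * V))

def nmda_current(g, V, E=0.):
    # current injected into the post-synaptic neuron for total conductance g = g_max * g(t)
    return g * (E - V) * mg_block(V)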

@@ -1091,11 +719,12 @@ class NMDA(TwoEndConn):
:include-source: True

>>> import brainpy as bp
>>> from brainpy.dyn import synapses, neurons
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.HH(1)
>>> neu2 = bp.dyn.HH(1)
>>> syn1 = bp.dyn.NMDA(neu1, neu2, bp.connect.All2All(), E=0.)
>>> neu1 = neurons.HH(1)
>>> neu2 = neurons.HH(1)
>>> syn1 = synapses.NMDA(neu1, neu2, bp.connect.All2All(), E=0.)
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x'])
@@ -1121,21 +750,13 @@ class NMDA(TwoEndConn):
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `dense`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
E: float, JaxArray, ndarray
The reversal potential for the synaptic current. [mV]
alpha: float, JaxArray, ndarray
Binding constant. Default 0.062
beta: float, JaxArray, ndarray
Unbinding constant. Default 3.57
cc_Mg: float, JaxArray, ndarray
Concentration of Magnesium ion. Default 1.2 [mM].
tau_decay: float, JaxArray, ndarray
The time constant of the synaptic decay phase. Default 100 [ms]
tau_rise: float, JaxArray, ndarray
@@ -1167,83 +788,53 @@ class NMDA(TwoEndConn):
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'dense',
g_max: Union[float, Tensor, Initializer, Callable] = 0.15,
delay_step: Union[int, Tensor, Initializer, Callable] = None,
E: Union[float, Tensor] = 0.,
cc_Mg: Union[float, Tensor] = 1.2,
alpha: Union[float, Tensor] = 0.062,
beta: Union[float, Tensor] = 3.57,
tau_decay: Union[float, Tensor] = 100.,
a: Union[float, Tensor] = 0.5,
tau_rise: Union[float, Tensor] = 2.,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
output: SynOut = MgBlock(E=0., alpha=0.062, beta=3.57, cc_Mg=1.2),
stp: Optional[SynSTP] = None,
comp_method: str = 'dense',
g_max: Union[float, Array, Initializer, Callable] = 0.15,
delay_step: Union[int, Array, Initializer, Callable] = None,
tau_decay: Union[float, Array] = 100.,
a: Union[float, Array] = 0.5,
tau_rise: Union[float, Array] = 2.,
method: str = 'exp_auto',

# other parameters
name: str = None,
mode: Mode = normal,
stop_spike_gradient: bool = False,
):
super(NMDA, self).__init__(pre=pre, post=post, conn=conn, name=name)
self.check_pre_attrs('spike')
self.check_post_attrs('input', 'V')

super(NMDA, self).__init__(pre=pre,
post=post,
conn=conn,
output=output,
stp=stp,
name=name,
mode=mode)
# parameters
self.E = E
self.alpha = alpha
self.beta = beta
self.cc_Mg = cc_Mg
# self.check_post_attrs('input', 'V')
self.tau_decay = tau_decay
self.tau_rise = tau_rise
self.a = a
if bm.size(a) != 1:
raise ValueError(f'"a" must be a scalar or a tensor with size of 1. But we got {a}')
if bm.size(E) != 1:
raise ValueError(f'"E" must be a scalar or a tensor with size of 1. But we got {E}')
if bm.size(alpha) != 1:
raise ValueError(f'"alpha" must be a scalar or a tensor with size of 1. But we got {alpha}')
if bm.size(beta) != 1:
raise ValueError(f'"beta" must be a scalar or a tensor with size of 1. But we got {beta}')
if bm.size(cc_Mg) != 1:
raise ValueError(f'"cc_Mg" must be a scalar or a tensor with size of 1. But we got {cc_Mg}')
if bm.size(tau_decay) != 1:
raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. But we got {tau_decay}')
if bm.size(tau_rise) != 1:
raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. But we got {tau_rise}')
self.comp_method = comp_method
self.stop_spike_gradient = stop_spike_gradient

# connections and weights
self.conn_type = conn_type
if conn_type not in ['sparse', 'dense']:
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
if isinstance(self.conn, One2One):
self.g_max = init_param(g_max, (self.pre.num,), allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
elif isinstance(self.conn, All2All):
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
if bm.size(self.g_max) != 1:
self.weight_type = 'heter'
bm.fill_diagonal(self.g_max, 0.)
else:
self.weight_type = 'homo'
else:
if conn_type == 'sparse':
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
elif conn_type == 'dense':
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
if self.weight_type == 'homo':
self.conn_mat = self.conn.require('conn_mat')
else:
raise ValueError(f'Unknown connection type: {conn_type}')
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
self.x = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
self.g = variable(bm.zeros, mode, self.pre.num)
self.x = variable(bm.zeros, mode, self.pre.num)
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

# integral
self.integral = odeint(method=method, f=JointEq([self.dg, self.dx]))
self.integral = odeint(method=method, f=JointEq(self.dg, self.dx))

def dg(self, g, t, x):
return -g / self.tau_decay + self.a * x * (1 - g)
@@ -1251,40 +842,43 @@ class NMDA(TwoEndConn):
def dx(self, x, t):
return -x / self.tau_rise

def update(self, t, dt):
def reset_state(self, batch_size=None):
self.g.value = variable(bm.zeros, batch_size, self.pre.num)
self.x.value = variable(bm.zeros, batch_size, self.pre.num)
self.output.reset_state(batch_size)
if self.stp is not None: self.stp.reset_state(batch_size)

def update(self, tdi, pre_spike=None):
t, dt = tdi['t'], tdi['dt']
# delays
if self.delay_step is None:
delayed_pre_spike = self.pre.spike
else:
delayed_pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
self.update_delay(f"{self.pre.name}.spike", self.pre.spike)
if pre_spike is None:
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
if self.stop_spike_gradient:
pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike
pre_spike = stop_gradient(pre_spike)

# update sub-components
self.output.update(tdi)
if self.stp is not None: self.stp.update(tdi, pre_spike)

# update synapse variables
self.g.value, self.x.value = self.integral(self.g, self.x, t, dt=dt)
self.x += delayed_pre_spike
self.x += pre_spike

# post-synaptic value
assert self.weight_type in ['homo', 'heter']
assert self.conn_type in ['sparse', 'dense']
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
if self.weight_type == 'homo':
post_g = bm.sum(self.g)
if not self.conn.include_self:
post_g = post_g - self.g
post_g = post_g * self.g_max
else:
post_g = self.g @ self.g_max
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_g = self.g_max * self.g
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
else:
if self.conn_type == 'sparse':
post_g = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids)
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
if self.weight_type == 'homo':
post_g = (self.g_max * self.g) @ self.conn_mat
else:
post_g = self.g @ self.g_max
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
g_inf = 1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post.V)
self.post.input += post_g * (self.E - self.post.V) / g_inf
return self.output(post_vs)

+ 357
- 103
brainpy/dyn/synapses/biological_models.py View File

@@ -1,22 +1,29 @@
# -*- coding: utf-8 -*-

from typing import Union, Dict, Callable
import warnings
from typing import Union, Dict, Callable, Optional

from jax import vmap
from jax.lax import stop_gradient

import brainpy.math as bm
from brainpy.connect import TwoEndConnector, All2All, One2One
from brainpy.dyn.base import NeuGroup, TwoEndConn
from brainpy.initialize import Initializer, init_param
from brainpy.integrators import odeint
from brainpy.types import Tensor
from brainpy.dyn.base import NeuGroup, TwoEndConn, SynSTP, SynOut
from brainpy.dyn.synouts import COBA, MgBlock
from brainpy.initialize import Initializer, variable
from brainpy.integrators import odeint, JointEq
from brainpy.types import Array
from brainpy.modes import Mode, BatchingMode, TrainingMode, normal, batching, training

__all__ = [
'AMPA',
'GABAa',
'BioNMDA',
]


class AMPA(TwoEndConn):
r"""AMPA conductance-based synapse model.
r"""AMPA synapse model.

**Model Descriptions**

@@ -62,11 +69,12 @@ class AMPA(TwoEndConn):
:include-source: True

>>> import brainpy as bp
>>> from brainpy.dyn import neurons, synapses
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.HH(1)
>>> neu2 = bp.dyn.HH(1)
>>> syn1 = bp.dyn.AMPA(neu1, neu2, bp.connect.All2All())
>>> neu1 = neurons.HH(1)
>>> neu2 = neurons.HH(1)
>>> syn1 = synapses.AMPA(neu1, neu2, bp.connect.All2All())
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g'])
@@ -91,13 +99,18 @@ class AMPA(TwoEndConn):
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `dense`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
E: float, JaxArray, ndarray
The reversal potential for the synaptic current. [mV]

.. deprecated:: 2.1.13
`E` is deprecated in the AMPA model. Please define `E` with brainpy.dyn.synouts.COBA.
This parameter will be removed in version 2.2.0.

g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
alpha: float, JaxArray, ndarray
@@ -126,30 +139,38 @@ class AMPA(TwoEndConn):
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'dense',
g_max: Union[float, Tensor, Initializer, Callable] = 0.42,
delay_step: Union[int, Tensor, Initializer, Callable] = None,
E: Union[float, Tensor] = 0.,
alpha: Union[float, Tensor] = 0.98,
beta: Union[float, Tensor] = 0.18,
T: Union[float, Tensor] = 0.5,
T_duration: Union[float, Tensor] = 0.5,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
output: SynOut = COBA(E=0.),
stp: Optional[SynSTP] = None,
comp_method: str = 'dense',
g_max: Union[float, Array, Initializer, Callable] = 0.42,
delay_step: Union[int, Array, Initializer, Callable] = None,
alpha: float = 0.98,
beta: float = 0.18,
T: float = 0.5,
T_duration: float = 0.5,
method: str = 'exp_auto',
name: str = None

# other parameters
name: str = None,
mode: Mode = normal,
stop_spike_gradient: bool = False,
):
super(AMPA, self).__init__(pre=pre, post=post, conn=conn, name=name)
self.check_pre_attrs('spike')
self.check_post_attrs('input', 'V')
super(AMPA, self).__init__(pre=pre,
post=post,
conn=conn,
output=output,
stp=stp,
name=name,
mode=mode)

# parameters
self.E = E
self.stop_spike_gradient = stop_spike_gradient
self.comp_method = comp_method
self.alpha = alpha
self.beta = beta
self.T = T
self.T_duration = T_duration
if bm.size(E) != 1:
raise ValueError(f'"E" must be a scalar or a tensor with size of 1. But we got {E}')
if bm.size(alpha) != 1:
raise ValueError(f'"alpha" must be a scalar or a tensor with size of 1. But we got {alpha}')
if bm.size(beta) != 1:
@@ -160,92 +181,68 @@ class AMPA(TwoEndConn):
raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}')

# connection
self.conn_type = conn_type
if conn_type not in ['sparse', 'dense']:
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
if isinstance(self.conn, One2One):
self.g_max = init_param(g_max, (self.pre.num,), allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
elif isinstance(self.conn, All2All):
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
if bm.size(self.g_max) != 1:
self.weight_type = 'heter'
bm.fill_diagonal(self.g_max, 0.)
else:
self.weight_type = 'homo'
else:
if conn_type == 'sparse':
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
elif conn_type == 'dense':
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
if self.weight_type == 'homo':
self.conn_mat = self.conn.require('conn_mat')
else:
raise ValueError(f'Unknown connection type: {conn_type}')
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = bm.Variable(bm.zeros(self.pre.num))
self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num) * -1e7)
self.delay_step = self.register_delay(f"{self.pre.name}.spike",
delay_step=delay_step,
delay_target=self.pre.spike)
self.g = variable(bm.zeros, mode, self.pre.num)
self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, mode, self.pre.num)
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

# functions
self.integral = odeint(method=method, f=self.dg)

def reset(self):
self.g.value = bm.zeros(self.pre.num)
if self.delay_step is not None:
self.reset_delay(f"{self.pre.name}.spike", self.pre.spike)
def reset_state(self, batch_size=None):
self.g = variable(bm.zeros, batch_size, self.pre.num)
self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num)
self.output.reset_state(batch_size)
if self.stp is not None: self.stp.reset_state(batch_size)

def dg(self, g, t, TT):
dg = self.alpha * TT * (1 - g) - self.beta * g
return dg

def update(self, t, dt):
def update(self, tdi, pre_spike=None):
t, dt = tdi['t'], tdi['dt']

# delays
if self.delay_step is None:
pre_spike = self.pre.spike
else:
if pre_spike is None:
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
self.update_delay(f"{self.pre.name}.spike", self.pre.spike)
if self.stop_spike_gradient:
pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike
pre_spike = stop_gradient(pre_spike)

# spike arrival time
# update sub-components
self.output.update(tdi)
if self.stp is not None: self.stp.update(tdi, pre_spike)

# update synaptic variables
self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time)
if isinstance(self.mode, TrainingMode):
self.spike_arrival_time.value = stop_gradient(self.spike_arrival_time.value)
TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T
self.g.value = self.integral(self.g, t, TT, dt)

# post-synaptic values
TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T
self.g.value = self.integral(self.g, t, TT, dt=dt)
if isinstance(self.conn, One2One):
post_g = self.g_max * self.g
elif isinstance(self.conn, All2All):
if self.weight_type == 'homo':
post_g = bm.sum(self.g)
if not self.conn.include_self:
post_g = post_g - self.g
post_g = post_g * self.g_max
else:
post_g = self.g @ self.g_max
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
else:
if self.conn_type == 'sparse':
post_g = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids)
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
if self.weight_type == 'homo':
post_g = (self.g_max * self.g) @ self.conn_mat
else:
post_g = self.g @ self.g_max
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
self.post.input -= post_g * (self.post.V - self.E)
return self.output(post_vs)


class GABAa(AMPA):
r"""GABAa conductance-based synapse model.
r"""GABAa synapse model.

**Model Descriptions**

@@ -277,13 +274,18 @@ class GABAa(AMPA):
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `dense`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
E: float, JaxArray, ndarray
The reversal potential for the synaptic current. [mV]

.. deprecated:: 2.1.13
`E` is deprecated in the GABAa model. Please define `E` with brainpy.dyn.synouts.COBA.
This parameter will be removed in version 2.2.0.

g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
alpha: float, JaxArray, ndarray
@@ -311,26 +313,278 @@ class GABAa(AMPA):
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'dense',
g_max: Union[float, Tensor, Initializer, Callable] = 0.04,
delay_step: Union[int, Tensor, Initializer, Callable] = None,
E: Union[float, Tensor] = -80.,
alpha: Union[float, Tensor] = 0.53,
beta: Union[float, Tensor] = 0.18,
T: Union[float, Tensor] = 1.,
T_duration: Union[float, Tensor] = 1.,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
output: SynOut = COBA(E=-80.),
stp: Optional[SynSTP] = None,
comp_method: str = 'dense',
g_max: Union[float, Array, Initializer, Callable] = 0.04,
delay_step: Union[int, Array, Initializer, Callable] = None,
alpha: Union[float, Array] = 0.53,
beta: Union[float, Array] = 0.18,
T: Union[float, Array] = 1.,
T_duration: Union[float, Array] = 1.,
method: str = 'exp_auto',
name: str = None

# other parameters
name: str = None,
mode: Mode = normal,
stop_spike_gradient: bool = False,

# deprecated
E: Union[float, Array] = None,
):
super(GABAa, self).__init__(pre, post, conn,
conn_type=conn_type,
super(GABAa, self).__init__(pre=pre,
post=post,
conn=conn,
output=output,
stp=stp,
comp_method=comp_method,
delay_step=delay_step,
g_max=g_max,
E=E,
alpha=alpha,
beta=beta,
T=T,
T_duration=T_duration,
method=method,
name=name)
name=name,
mode=mode,
stop_spike_gradient=stop_spike_gradient, )


class BioNMDA(TwoEndConn):
r"""Biological NMDA synapse model.

**Model Descriptions**

The NMDA receptor is a glutamate receptor and ion channel found in neurons.
The NMDA receptor is one of three types of ionotropic glutamate receptors,
the other two being AMPA and kainate receptors.

The NMDA receptor-mediated conductance depends on the postsynaptic voltage.
The voltage dependence is due to the blocking of the pore of the NMDA receptor
from the outside by a positively charged magnesium ion. The channel is
nearly completely blocked at resting potential, but the magnesium block is
relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}`
that are not blocked by magnesium can be fitted to

.. math::

g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V}
\frac{[{Mg}^{2+}]_{o}} {\beta})^{-1}

Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration,
usually 1 mM. Thus, the channel acts as a "coincidence detector": only when
both conditions are met, i.e. pre-synaptic transmitter release and
post-synaptic depolarization, does the channel open, allowing positively
charged ions (cations) to flow through the cell membrane [2]_.

If we make the approximation that the magnesium block changes
instantaneously with voltage and is independent of the gating of the channel,
the net NMDA receptor-mediated synaptic current is given by

.. math::

I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}

where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the
reversal potential.

Simultaneously, the synaptic state :math:`g` is governed by second-order kinetics [1]_:

.. math::

& g_\mathrm{NMDA} (t) = g_{max} g \\
& \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\
& \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x

where :math:`\alpha_1, \beta_1` are the conversion rates of variable :math:`g`, and
:math:`\alpha_2, \beta_2` are the conversion rates of variable :math:`x`.

The NMDA receptor has been thought to be very important for controlling
synaptic plasticity and mediating learning and memory functions [3]_.
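
A standalone sketch (not taken from the diff) of these second-order kinetics,
integrated with the same tools the class uses below (``odeint`` and ``JointEq``);
the parameter values are the class defaults and the variable names are illustrative:

from brainpy.integrators import odeint, JointEq

alpha1, beta1, alpha2, beta2 = 2., 0.01, 1., 0.5

def dg(g, t, x):
    # gating variable g, driven by the intermediate variable x
    return alpha1 * x * (1 - g) - beta1 * g

def dx(x, t, T):
    # intermediate variable x, driven by the transmitter concentration T
    return alpha2 * T * (1 - x) - beta2 * x

integral = odeint(method='exp_auto', f=JointEq(dg, dx))

g, x = 0., 0.
g, x = integral(g, x, 0., 1., 0.1)  # one step with transmitter T=1 and dt=0.1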

.. plot::
:include-source: True

>>> import brainpy as bp
>>> from brainpy.dyn import neurons, synapses
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = neurons.HH(1)
>>> neu2 = neurons.HH(1)
>>> syn1 = synapses.BioNMDA(neu1, neu2, bp.connect.All2All(), E=0.)
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>>
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x')
>>> plt.legend()
>>> plt.show()

Parameters
----------
pre: NeuGroup
The pre-synaptic neuron group.
post: NeuGroup
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `dense`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
alpha1: float, JaxArray, ndarray
The conversion rate of g from inactive to active. Default 2 ms^-1.
beta1: float, JaxArray, ndarray
The conversion rate of g from active to inactive. Default 0.01 ms^-1.
alpha2: float, JaxArray, ndarray
The conversion rate of x from inactive to active. Default 1 ms^-1.
beta2: float, JaxArray, ndarray
The conversion rate of x from active to inactive. Default 0.5 ms^-1.
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.

References
----------

.. [1] Devaney A J. Mathematical Foundations of Neuroscience.
Springer New York, 2010: 162.
.. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and
Eric Gouaux. "Subunit arrangement and function in NMDA receptors."
Nature 438, no. 7065 (2005): 185-192.
.. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New
England journal of medicine, 361(3), p.302.
.. [4] https://en.wikipedia.org/wiki/NMDA_receptor

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
output: SynOut = MgBlock(E=0.),
stp: Optional[SynSTP] = None,
comp_method: str = 'dense',
g_max: Union[float, Array, Initializer, Callable] = 0.15,
delay_step: Union[int, Array, Initializer, Callable] = None,
alpha1: Union[float, Array] = 2.,
beta1: Union[float, Array] = 0.01,
alpha2: Union[float, Array] = 1.,
beta2: Union[float, Array] = 0.5,
T_0: Union[float, Array] = 1.,
T_dur: Union[float, Array] = 0.5,
method: str = 'exp_auto',

# other parameters
mode: Mode = normal,
name: str = None,
stop_spike_gradient: bool = False,
):
super(BioNMDA, self).__init__(pre=pre,
post=post,
conn=conn,
output=output,
stp=stp,
name=name,
mode=mode)

# parameters
self.beta1 = beta1
self.beta2 = beta2
self.alpha1 = alpha1
self.alpha2 = alpha2
self.T_0 = T_0
self.T_dur = T_dur
if bm.size(alpha1) != 1:
raise ValueError(f'"alpha1" must be a scalar or a tensor with size of 1. But we got {alpha1}')
if bm.size(beta1) != 1:
raise ValueError(f'"beta1" must be a scalar or a tensor with size of 1. But we got {beta1}')
if bm.size(alpha2) != 1:
raise ValueError(f'"alpha2" must be a scalar or a tensor with size of 1. But we got {alpha2}')
if bm.size(beta2) != 1:
raise ValueError(f'"beta2" must be a scalar or a tensor with size of 1. But we got {beta2}')
if bm.size(T_0) != 1:
raise ValueError(f'"T_0" must be a scalar or a tensor with size of 1. But we got {T_0}')
if bm.size(T_dur) != 1:
raise ValueError(f'"T_dur" must be a scalar or a tensor with size of 1. But we got {T_dur}')
self.comp_method = comp_method
self.stop_spike_gradient = stop_spike_gradient

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = variable(bm.zeros, mode, self.pre.num)
self.x = variable(bm.zeros, mode, self.pre.num)
self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, mode, self.pre.num)
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

# integral
self.integral = odeint(method=method, f=JointEq([self.dg, self.dx]))

def reset_state(self, batch_size=None):
self.g = variable(bm.zeros, batch_size, self.pre.num)
self.x = variable(bm.zeros, batch_size, self.pre.num)
self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num)
self.output.reset_state(batch_size)
if self.stp is not None: self.stp.reset_state(batch_size)

def dg(self, g, t, x):
return self.alpha1 * x * (1 - g) - self.beta1 * g

def dx(self, x, t, T):
return self.alpha2 * T * (1 - x) - self.beta2 * x

def update(self, tdi, pre_spike=None):
t, dt = tdi['t'], tdi['dt']

# pre-synaptic spikes
if pre_spike is None:
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
if self.stop_spike_gradient:
pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike
pre_spike = stop_gradient(pre_spike)

# update sub-components
self.output.update(tdi)
if self.stp is not None: self.stp.update(tdi, pre_spike)

# update synapse variables
self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time)
if isinstance(self.mode, TrainingMode):
self.spike_arrival_time.value = stop_gradient(self.spike_arrival_time.value)
T = ((t - self.spike_arrival_time) < self.T_dur) * self.T_0
self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt)

# post-synaptic value
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)

+ 258
- 0
brainpy/dyn/synapses/compat.py View File

@@ -0,0 +1,258 @@
# -*- coding: utf-8 -*-
import warnings
from typing import Union, Dict, Callable

from brainpy.connect import TwoEndConnector
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import Initializer
from brainpy.types import Array
from .abstract_models import Delta, Exponential, DualExponential, NMDA
from ..synouts import COBA, CUBA

__all__ = [
'DeltaSynapse',
'ExpCUBA',
'ExpCOBA',
'DualExpCUBA',
'DualExpCOBA',
'AlphaCUBA',
'AlphaCOBA',
'NMDA',
]


class DeltaSynapse(Delta):
"""Delta synapse.

.. deprecated:: 2.1.13
Please use "brainpy.dyn.synapses.Delta" instead.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
conn_type: str = 'sparse',
weights: Union[float, Array, Initializer, Callable] = 1.,
delay_step: Union[float, Array, Initializer, Callable] = None,
post_input_key: str = 'V',
post_has_ref: bool = False,
name: str = None,
):
warnings.warn('Please use "brainpy.dyn.synapses.Delta" instead.', DeprecationWarning)
super(DeltaSynapse, self).__init__(pre=pre,
post=post,
conn=conn,
output=CUBA(),
name=name,
comp_method=conn_type,
g_max=weights,
delay_step=delay_step,
post_input_key=post_input_key,
post_ref_key='refractory' if post_has_ref else None)


class ExpCUBA(Exponential):
r"""Current-based exponential decay synapse model.

.. deprecated:: 2.1.13
Please use "brainpy.dyn.synapses.Exponential" instead.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
conn_type: str = 'sparse',
g_max: Union[float, Array, Initializer, Callable] = 1.,
delay_step: Union[int, Array, Initializer, Callable] = None,
tau: Union[float, Array] = 8.0,
name: str = None,
method: str = 'exp_auto',
):
super(ExpCUBA, self).__init__(pre=pre,
post=post,
conn=conn,
name=name,
comp_method=conn_type,
g_max=g_max,
delay_step=delay_step,
tau=tau,
method=method,
output=CUBA())


class ExpCOBA(Exponential):
"""Conductance-based exponential decay synapse model.

.. deprecated:: 2.1.13
Please use "brainpy.dyn.synapses.Exponential" instead.
"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
# connection
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
conn_type: str = 'sparse',
# connection strength
g_max: Union[float, Array, Initializer, Callable] = 1.,
# synapse parameter
tau: Union[float, Array] = 8.0,
E: Union[float, Array] = 0.,
# synapse delay
delay_step: Union[int, Array, Initializer, Callable] = None,
# others
method: str = 'exp_auto',
name: str = None
):
super(ExpCOBA, self).__init__(pre=pre,
post=post,
conn=conn,
comp_method=conn_type,
g_max=g_max,
delay_step=delay_step,
tau=tau,
method=method,
name=name,
output=COBA(E=E))
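
For readers migrating off these wrappers, a hedged sketch of the equivalent call
under the new API (the LIF groups below are placeholders):

import brainpy as bp
from brainpy.dyn import neurons, synapses, synouts

pre, post = neurons.LIF(10), neurons.LIF(10)

# deprecated since 2.1.13:
#   syn = bp.dyn.ExpCOBA(pre, post, bp.connect.All2All(), g_max=1., tau=8., E=0.)
# replacement:
syn = synapses.Exponential(pre, post, bp.connect.All2All(),
                           g_max=1., tau=8., output=synouts.COBA(E=0.))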


class DualExpCUBA(DualExponential):
r"""Current-based dual exponential synapse model.

.. deprecated:: 2.1.13
Please use "brainpy.dyn.synapses.DualExponential" instead.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
conn_type: str = 'dense',
g_max: Union[float, Array, Initializer, Callable] = 1.,
tau_decay: Union[float, Array] = 10.0,
tau_rise: Union[float, Array] = 1.,
delay_step: Union[int, Array, Initializer, Callable] = None,
method: str = 'exp_auto',
name: str = None
):
super(DualExpCUBA, self).__init__(pre=pre,
post=post,
conn=conn,
comp_method=conn_type,
g_max=g_max,
tau_decay=tau_decay,
tau_rise=tau_rise,
delay_step=delay_step,
method=method,
name=name,
output=CUBA())


class DualExpCOBA(DualExponential):
"""Conductance-based dual exponential synapse model.


.. deprecated:: 2.1.13
Please use "brainpy.dyn.synapses.DualExponential" instead.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
conn_type: str = 'dense',
g_max: Union[float, Array, Initializer, Callable] = 1.,
delay_step: Union[int, Array, Initializer, Callable] = None,
tau_decay: Union[float, Array] = 10.0,
tau_rise: Union[float, Array] = 1.,
E: Union[float, Array] = 0.,
method: str = 'exp_auto',
name: str = None
):
super(DualExpCOBA, self).__init__(pre=pre,
post=post,
conn=conn,
comp_method=conn_type,
g_max=g_max,
tau_decay=tau_decay,
tau_rise=tau_rise,
delay_step=delay_step,
method=method,
name=name,
output=COBA(E=E))


class AlphaCUBA(DualExpCUBA):
r"""Current-based alpha synapse model.

.. deprecated:: 2.1.13
Please use "brainpy.dyn.synapses.Alpha" instead.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
conn_type: str = 'dense',
g_max: Union[float, Array, Initializer, Callable] = 1.,
delay_step: Union[int, Array, Initializer, Callable] = None,
tau_decay: Union[float, Array] = 10.0,
method: str = 'exp_auto',
name: str = None
):
super(AlphaCUBA, self).__init__(pre=pre,
post=post,
conn=conn,
conn_type=conn_type,
delay_step=delay_step,
g_max=g_max,
tau_decay=tau_decay,
tau_rise=tau_decay,
method=method,
name=name)


class AlphaCOBA(DualExpCOBA):
"""Conductance-based alpha synapse model.

.. deprecated:: 2.1.13
Please use "brainpy.dyn.synapses.Alpha" instead.

"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
conn_type: str = 'dense',
g_max: Union[float, Array, Callable, Initializer] = 1.,
delay_step: Union[int, Array, Initializer, Callable] = None,
tau_decay: Union[float, Array] = 10.0,
E: Union[float, Array] = 0.,
method: str = 'exp_auto',
name: str = None
):
super(AlphaCOBA, self).__init__(pre=pre,
post=post,
conn=conn,
conn_type=conn_type,
delay_step=delay_step,
g_max=g_max, E=E,
tau_decay=tau_decay,
tau_rise=tau_decay,
method=method,
name=name)

brainpy/dyn/rates/couplings.py → brainpy/dyn/synapses/delay_couplings.py View File

@@ -6,10 +6,13 @@ import jax.numpy as jnp
from jax import vmap

import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.dyn.base import SynConn, SynOut
from brainpy.dyn.synouts import CUBA
from brainpy.initialize import Initializer
from brainpy.tools.checking import check_sequence, check_integer
from brainpy.types import Tensor
from brainpy.dyn.neurons.input_groups import InputGroup, OutputGroup
from brainpy.modes import Mode, TrainingMode, normal
from brainpy.tools.checking import check_sequence
from brainpy.types import Array

__all__ = [
'DelayCoupling',
@@ -18,14 +21,14 @@ __all__ = [
]


class DelayCoupling(DynamicalSystem):
class DelayCoupling(SynConn):
"""Delay coupling.

Parameters
----------
delay_var: Variable
The delay variable.
target_var: Variable, sequence of Variable
var_to_output: Variable, sequence of Variable
The target variables to output.
conn_mat: JaxArray, ndarray
The connection matrix.
@@ -40,14 +43,18 @@ class DelayCoupling(DynamicalSystem):
def __init__(
self,
delay_var: bm.Variable,
target_var: Union[bm.Variable, Sequence[bm.Variable]],
conn_mat: Tensor,
var_to_output: Union[bm.Variable, Sequence[bm.Variable]],
conn_mat: Array,
required_shape: Tuple[int, ...],
delay_steps: Optional[Union[int, Tensor, Initializer, Callable]] = None,
initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None,
name: str = None
delay_steps: Optional[Union[int, Array, Initializer, Callable]] = None,
initial_delay_data: Union[Initializer, Callable, Array, float, int, bool] = None,
name: str = None,
mode: Mode = normal,
):
super(DelayCoupling, self).__init__(name=name)
super(DelayCoupling, self).__init__(name=name,
mode=mode,
pre=InputGroup(1),
post=OutputGroup(1))

# delay variable
if not isinstance(delay_var, bm.Variable):
@@ -56,10 +63,10 @@ class DelayCoupling(DynamicalSystem):
self.delay_var = delay_var

# output variables
if isinstance(target_var, bm.Variable):
target_var = [target_var]
check_sequence(target_var, 'output_var', elem_type=bm.Variable, allow_none=False)
self.output_var = target_var
if isinstance(var_to_output, bm.Variable):
var_to_output = [var_to_output]
check_sequence(var_to_output, 'output_var', elem_type=bm.Variable, allow_none=False)
self.output_var = var_to_output

# Connection matrix
self.conn_mat = bm.asarray(conn_mat)
@@ -72,41 +79,42 @@ class DelayCoupling(DynamicalSystem):
if delay_steps is None:
self.delay_steps = None
self.delay_type = 'none'
num_delay_step = 0
elif isinstance(delay_steps, int):
self.delay_steps = delay_steps
num_delay_step = delay_steps
check_integer(delay_steps, 'delay_steps', min_bound=0, allow_none=False)
self.delay_type = 'int'
num_delay_step = None
elif callable(delay_steps):
delay_steps = delay_steps(required_shape)
if delay_steps.dtype not in [bm.int32, bm.int64, bm.uint32, bm.uint64]:
raise ValueError(f'"delay_steps" must be integer typed. But we got {delay_steps.dtype}')
self.delay_steps = delay_steps
self.delay_type = 'array'
num_delay_step = int(self.delay_steps.max())
num_delay_step = self.delay_steps.max()
elif isinstance(delay_steps, (bm.JaxArray, jnp.ndarray)):
if delay_steps.dtype not in [bm.int32, bm.int64, bm.uint32, bm.uint64]:
raise ValueError(f'"delay_steps" must be integer typed. But we got {delay_steps.dtype}')
if delay_steps.shape != required_shape:
raise ValueError(f'we expect the delay matrix has the shape of {required_shape}. '
f'While we got {delay_steps.shape}.')
if delay_steps.ndim == 0:
self.delay_type = 'int'
else:
self.delay_type = 'array'
if delay_steps.shape != required_shape:
raise ValueError(f'we expect the delay matrix has the shape of '
f'(pre.num, post.num), i.e., {required_shape}. '
f'While we got {delay_steps.shape}.')
self.delay_steps = delay_steps
self.delay_type = 'array'
num_delay_step = int(self.delay_steps.max())
num_delay_step = self.delay_steps.max()
elif isinstance(delay_steps, int):
self.delay_steps = delay_steps
num_delay_step = delay_steps
self.delay_type = 'int'
else:
raise ValueError(f'Unknown type of delay steps: {type(delay_steps)}')

# delay variables
if self.delay_type != 'none':
self.register_delay(f'delay_{id(delay_var)}',
delay_step=num_delay_step,
delay_target=delay_var,
initial_delay_data=initial_delay_data)
_ = self.register_delay(f'delay_{id(delay_var)}',
delay_step=num_delay_step,
delay_target=delay_var,
initial_delay_data=initial_delay_data)

def reset(self):
if self.delay_steps is not None:
self.reset_delay(f'delay_{id(self.delay_var)}', self.delay_var)
def reset_state(self, batch_size=None):
pass


class DiffusiveCoupling(DelayCoupling):
@@ -115,7 +123,7 @@ class DiffusiveCoupling(DelayCoupling):
This class simulates the model of::

coupling = g * (delayed_coupling_var1 - coupling_var2)
output_var += coupling
target_var += coupling


Examples
@@ -123,10 +131,10 @@ class DiffusiveCoupling(DelayCoupling):

>>> import brainpy as bp
>>> from brainpy.dyn import rates
>>> areas = rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn')
>>> conn = rates.DiffusiveCoupling(areas.x, areas.x, areas.input,
>>> conn_mat=Cmat, delay_steps=Dmat,
>>> initial_delay_data=bp.init.Uniform(0, 0.05))
>>> areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn')
>>> conn = bp.synapses.DiffusiveCoupling(areas.x, areas.x, areas.input,
>>> conn_mat=Cmat, delay_steps=Dmat,
>>> initial_delay_data=bp.init.Uniform(0, 0.05))
>>> net = bp.dyn.Network(areas, conn)

Parameters
@@ -135,7 +143,7 @@ class DiffusiveCoupling(DelayCoupling):
The first coupling variable, used for delay.
coupling_var2: Variable
Another coupling variable.
target_var: Variable, sequence of Variable
var_to_output: Variable, sequence of Variable
The target variables to output.
conn_mat: JaxArray, ndarray
The connection matrix.
@@ -151,11 +159,12 @@ class DiffusiveCoupling(DelayCoupling):
self,
coupling_var1: bm.Variable,
coupling_var2: bm.Variable,
target_var: Union[bm.Variable, Sequence[bm.Variable]],
conn_mat: Tensor,
delay_steps: Optional[Union[int, Tensor, Initializer, Callable]] = None,
initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None,
name: str = None
var_to_output: Union[bm.Variable, Sequence[bm.Variable]],
conn_mat: Array,
delay_steps: Optional[Union[int, Array, Initializer, Callable]] = None,
initial_delay_data: Union[Initializer, Callable, Array, float, int, bool] = None,
name: str = None,
mode: Mode = normal,
):
if not isinstance(coupling_var1, bm.Variable):
raise ValueError(f'"coupling_var1" must be an instance of brainpy.math.Variable. '
@@ -172,60 +181,62 @@ class DiffusiveCoupling(DelayCoupling):

super(DiffusiveCoupling, self).__init__(
delay_var=coupling_var1,
target_var=target_var,
var_to_output=var_to_output,
conn_mat=conn_mat,
required_shape=(coupling_var1.size, coupling_var2.size),
delay_steps=delay_steps,
initial_delay_data=initial_delay_data,
name=name
name=name,
mode=mode,
)

self.coupling_var1 = coupling_var1
self.coupling_var2 = coupling_var2

def update(self, t, dt):
# delay variable
if self.delay_type != 'none':
delay_var: bm.LengthDelay = self.global_delay_vars[f'delay_{id(self.delay_var)}']

def update(self, tdi):
# delays
if self.delay_type == 'none':
diffusive = bm.expand_dims(self.coupling_var1, axis=1) - self.coupling_var2
diffusive = (self.conn_mat * diffusive).sum(axis=0)
axis = self.coupling_var1.ndim
delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0]
if self.delay_steps is None:
diffusive = (bm.expand_dims(self.coupling_var1, axis=axis) -
bm.expand_dims(self.coupling_var2, axis=axis - 1))
diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1)
elif self.delay_type == 'array':
f = vmap(lambda i: delay_var(self.delay_steps[i], bm.arange(self.coupling_var1.size))) # (pre.num,)
delays = f(bm.arange(self.coupling_var2.size).value)
diffusive = delays.T - self.coupling_var2 # (post.num, pre.num)
diffusive = (self.conn_mat * diffusive).sum(axis=0)
if isinstance(self.mode, TrainingMode):
indices = (slice(None, None, None), bm.arange(self.coupling_var1.size),)
else:
indices = (bm.arange(self.coupling_var1.size),)
f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (..., pre.num)
delays = f(self.delay_steps) # (..., post.num, pre.num)
diffusive = (bm.moveaxis(delays, axis - 1, axis) -
bm.expand_dims(self.coupling_var2, axis=axis - 1)) # (..., pre.num, post.num)
diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1)
elif self.delay_type == 'int':
delayed_var = delay_var(self.delay_steps)
diffusive = bm.expand_dims(delayed_var, axis=1) - self.coupling_var2
diffusive = (self.conn_mat * diffusive).sum(axis=0)
delayed_data = delay_var(self.delay_steps) # (..., pre.num)
diffusive = (bm.expand_dims(delayed_data, axis=axis) -
bm.expand_dims(self.coupling_var2, axis=axis - 1)) # (..., pre.num, post.num)
diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1)
else:
raise ValueError
raise ValueError(f'Unknown delay type {self.delay_type}')

# output to target variable
for target in self.output_var:
target.value += diffusive

# update
if self.delay_type != 'none':
delay_var.update(self.delay_var)


class AdditiveCoupling(DelayCoupling):
"""Additive coupling.

This class simulates the model of::

coupling = g * delayed_coupling_var1
output_var += coupling
coupling = g * delayed_coupling_var
target_var += coupling
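
A hedged usage sketch, mirroring the ``DiffusiveCoupling`` example above and assuming
``AdditiveCoupling`` is exported next to it under ``bp.synapses``; ``Cmat`` and ``Dmat``
are placeholder coupling-strength and integer delay-step matrices:

>>> import brainpy as bp
>>> import brainpy.math as bm
>>> areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn')
>>> Cmat = bm.ones((80, 80))                   # coupling strengths (placeholder)
>>> Dmat = bm.zeros((80, 80), dtype=bm.int32)  # integer delay steps (placeholder)
>>> conn = bp.synapses.AdditiveCoupling(areas.x, areas.input,
>>>                                     conn_mat=Cmat, delay_steps=Dmat,
>>>                                     initial_delay_data=bp.init.Uniform(0, 0.05))
>>> net = bp.dyn.Network(areas, conn)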

Parameters
----------
coupling_var: Variable
The coupling variable, used for delay.
target_var: Variable, sequence of Variable
var_to_output: Variable, sequence of Variable
The target variables to output.
conn_mat: JaxArray, ndarray
The connection matrix.
@@ -240,11 +251,12 @@ class AdditiveCoupling(DelayCoupling):
def __init__(
self,
coupling_var: bm.Variable,
target_var: Union[bm.Variable, Sequence[bm.Variable]],
conn_mat: Tensor,
delay_steps: Optional[Union[int, Tensor, Initializer, Callable]] = None,
initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None,
name: str = None
var_to_output: Union[bm.Variable, Sequence[bm.Variable]],
conn_mat: Array,
delay_steps: Optional[Union[int, Array, Initializer, Callable]] = None,
initial_delay_data: Union[Initializer, Callable, Array, float, int, bool] = None,
name: str = None,
mode: Mode = normal,
):
if not isinstance(coupling_var, bm.Variable):
raise ValueError(f'"coupling_var" must be an instance of brainpy.math.Variable. '
@@ -255,31 +267,37 @@ class AdditiveCoupling(DelayCoupling):

super(AdditiveCoupling, self).__init__(
delay_var=coupling_var,
target_var=target_var,
var_to_output=var_to_output,
conn_mat=conn_mat,
required_shape=(coupling_var.size, coupling_var.size),
delay_steps=delay_steps,
initial_delay_data=initial_delay_data,
name=name
name=name,
mode=mode,
)

self.coupling_var = coupling_var

def update(self, t, dt):
# delay variable
delay_var: bm.LengthDelay = self.global_delay_vars[f'delay_{id(self.delay_var)}']

def update(self, tdi):
# delay function
axis = self.coupling_var.ndim
delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0]
if self.delay_steps is None:
additive = self.coupling_var @ self.conn_mat
elif self.delay_type == 'array':
if isinstance(self.mode, TrainingMode):
indices = (slice(None, None, None), bm.arange(self.coupling_var.size),)
else:
indices = (bm.arange(self.coupling_var.size),)
f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (.., pre.num,)
delays = f(self.delay_steps) # (..., post.num, pre.num)
additive = (self.conn_mat * bm.moveaxis(delays, axis - 1, axis)).sum(axis=axis - 1)
elif self.delay_type == 'int':
delayed_var = delay_var(self.delay_steps) # (..., pre.num)
additive = delayed_var @ self.conn_mat
else:
f = vmap(lambda i: delay_var(self.delay_steps[i], bm.arange(self.coupling_var.size))) # (pre.num,)
delays = f(bm.arange(self.coupling_var.size).value) # (post.num, pre.num)
additive = (self.conn_mat * delays.T).sum(axis=0)
raise ValueError

# output to target variable
for target in self.output_var:
target.value += additive

# update
delay_var.update(self.delay_var)

+ 62
- 0
brainpy/dyn/synapses/gap_junction.py View File

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-

from typing import Union, Dict, Callable, Optional

import brainpy.math as bm
from brainpy.connect import TwoEndConnector
from brainpy.dyn.base import NeuGroup, SynOut, SynSTP, TwoEndConn
from brainpy.initialize import Initializer, parameter
from brainpy.types import Array
from ..synouts import CUBA

__all__ = [
'GapJunction',
]


class GapJunction(TwoEndConn):
def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Array, Dict[str, Array]],
comp_method: str = 'dense',
g_max: Union[float, Array, Initializer, Callable] = 1.,
name: str = None,
):
super(GapJunction, self).__init__(pre=pre,
post=post,
conn=conn,
name=name)
# checking
self.check_pre_attrs('V', 'spike')
self.check_post_attrs('V', 'input', 'spike')

# assert isinstance(self.output, _NullSynOut)
# assert isinstance(self.stp, _NullSynSTP)

# connections
self.comp_method = comp_method
if comp_method == 'dense':
self.conn_mat = self.conn.require('conn_mat')
self.weights = parameter(g_max, (pre.num, post.num), allow_none=False)
elif comp_method == 'sparse':
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
self.weights = parameter(g_max, self.pre_ids.shape, allow_none=False)
else:
raise ValueError

def update(self, tdi):
if self.comp_method == 'dense':
# pre -> post
diff = (self.pre.V.reshape((-1, 1)) - self.post.V) * self.conn_mat * self.weights
self.post.input += bm.einsum('ij->j', diff)
# post -> pre
self.pre.input += bm.einsum('ij->i', -diff)
else:
diff = (self.pre.V[self.pre_ids] - self.post.V[self.post_ids]) * self.weights
self.post.input += bm.syn2post_sum(diff, self.post_ids, self.post.num)
self.pre.input += bm.syn2post_sum(-diff, self.pre_ids, self.pre.num)

def reset_state(self, batch_size=None):
pass
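
Since ``GapJunction`` ships without a docstring example, here is a hedged usage
sketch; the HH group size, the connector, and the conductance value are illustrative,
and ``GapJunction`` is assumed to be importable from ``brainpy.dyn.synapses``:

import brainpy as bp
from brainpy.dyn import neurons, synapses

neu = neurons.HH(5)
gj = synapses.GapJunction(neu, neu, bp.connect.All2All(include_self=False), g_max=0.1)
net = bp.dyn.Network(neu=neu, gj=gj)
runner = bp.dyn.DSRunner(net, inputs=[('neu.input', 5.)], monitors=['neu.V'])
runner.run(100.)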

Some files were not shown because too many files changed in this diff
