#29 version 2.1.2

Merged
BrainPy merged 116 commits from openi into master 2 years ago
  1. +0
    -10
      .github/ISSUE_TEMPLATE/Feature_request.md
  2. +1
    -1
      .github/ISSUE_TEMPLATE/bug_report.md
  3. +1
    -1
      .github/ISSUE_TEMPLATE/config.yml
  4. +3
    -3
      .github/workflows/Linux_CI.yml
  5. +5
    -6
      .github/workflows/MacOS_CI.yml
  6. +18
    -0
      .github/workflows/Sync_branches.yml
  7. +3
    -3
      .github/workflows/Windows_CI.yml
  8. +2
    -0
      .github/workflows/contributors.yml
  9. +1
    -0
      .gitignore
  10. +154
    -12
      README.md
  11. +0
    -159
      README2.md
  12. +6
    -4
      brainpy/__init__.py
  13. +39
    -21
      brainpy/analysis/highdim/slow_points.py
  14. +26
    -21
      brainpy/analysis/lowdim/lowdim_analyzer.py
  15. +74
    -60
      brainpy/analysis/lowdim/lowdim_bifurcation.py
  16. +67
    -49
      brainpy/analysis/lowdim/lowdim_phase_plane.py
  17. +4
    -2
      brainpy/analysis/utils/measurement.py
  18. +1
    -2
      brainpy/analysis/utils/model.py
  19. +4
    -4
      brainpy/analysis/utils/optimization.py
  20. +2
    -1
      brainpy/analysis/utils/others.py
  21. +27
    -0
      brainpy/check.py
  22. +0
    -12
      brainpy/compact/models.py
  23. +0
    -11
      brainpy/compact/runners.py
  24. +0
    -0
      brainpy/compat/__init__.py
  25. +30
    -0
      brainpy/compat/brainobjects.py
  26. +20
    -0
      brainpy/compat/integrators.py
  27. +5
    -2
      brainpy/compat/layers.py
  28. +98
    -0
      brainpy/compat/models.py
  29. +7
    -2
      brainpy/compat/monitor.py
  30. +65
    -0
      brainpy/compat/runners.py
  31. +2
    -2
      brainpy/connect/tests/test_regular_conn.py
  32. +2
    -2
      brainpy/datasets/chaotic_systems.py
  33. +0
    -752
      brainpy/dyn/neurons/IF_models.py
  34. +2
    -1
      brainpy/dyn/neurons/__init__.py
  35. +58
    -13
      brainpy/dyn/neurons/biological_models.py
  36. +294
    -0
      brainpy/dyn/neurons/fractional_models.py
  37. +72
    -0
      brainpy/dyn/neurons/noise_models.py
  38. +542
    -180
      brainpy/dyn/neurons/rate_models.py
  39. +1024
    -12
      brainpy/dyn/neurons/reduced_models.py
  40. +27
    -23
      brainpy/dyn/runners.py
  41. +0
    -3
      brainpy/dyn/runners/__init__.py
  42. +2
    -0
      brainpy/dyn/synapses/__init__.py
  43. +24
    -5
      brainpy/dyn/synapses/abstract_models.py
  44. +206
    -0
      brainpy/dyn/synapses/delay_coupling.py
  45. +2
    -1
      brainpy/errors.py
  46. +1
    -0
      brainpy/initialize/__init__.py
  47. +3
    -3
      brainpy/initialize/base.py
  48. +46
    -0
      brainpy/initialize/generic.py
  49. +7
    -8
      brainpy/initialize/random_inits.py
  50. +1
    -199
      brainpy/inputs/__init__.py
  51. +386
    -0
      brainpy/inputs/currents.py
  52. +10
    -1
      brainpy/integrators/__init__.py
  53. +8
    -2
      brainpy/integrators/base.py
  54. +17
    -5
      brainpy/integrators/dde/base.py
  55. +37
    -2
      brainpy/integrators/dde/explicit_rk.py
  56. +1
    -15
      brainpy/integrators/dde/generic.py
  57. +401
    -0
      brainpy/integrators/fde/Caputo.py
  58. +190
    -0
      brainpy/integrators/fde/GL.py
  59. +0
    -95
      brainpy/integrators/fde/RL.py
  60. +7
    -0
      brainpy/integrators/fde/__init__.py
  61. +76
    -2
      brainpy/integrators/fde/base.py
  62. +92
    -0
      brainpy/integrators/fde/generic.py
  63. +33
    -0
      brainpy/integrators/fde/tests/test_Caputo.py
  64. +32
    -0
      brainpy/integrators/fde/tests/test_GL.py
  65. +0
    -16
      brainpy/integrators/fde/tests/test_RL.py
  66. +1
    -1
      brainpy/integrators/joint_eq.py
  67. +39
    -0
      brainpy/integrators/ode/adaptive_rk.py
  68. +12
    -1
      brainpy/integrators/ode/base.py
  69. +37
    -0
      brainpy/integrators/ode/explicit_rk.py
  70. +9
    -0
      brainpy/integrators/ode/exponential.py
  71. +1
    -29
      brainpy/integrators/ode/generic.py
  72. +1
    -1
      brainpy/integrators/runner.py
  73. +1
    -12
      brainpy/integrators/sde/generic.py
  74. +14
    -0
      brainpy/integrators/sde/normal.py
  75. +11
    -1
      brainpy/integrators/sde/srk_scalar.py
  76. +37
    -0
      brainpy/integrators/utils.py
  77. +10
    -10
      brainpy/losses/__init__.py
  78. +2
    -3
      brainpy/math/__init__.py
  79. +2
    -3
      brainpy/math/autograd.py
  80. +3
    -1
      brainpy/math/compat/__init__.py
  81. +45
    -0
      brainpy/math/compat/delay_vars.py
  82. +40
    -0
      brainpy/math/compat/losses.py
  83. +60
    -0
      brainpy/math/compat/optimizers.py
  84. +183
    -107
      brainpy/math/delay_vars.py
  85. +8
    -8
      brainpy/math/numpy_ops.py
  86. +32
    -24
      brainpy/math/parallels.py
  87. +0
    -71
      brainpy/math/special.py
  88. +78
    -23
      brainpy/math/tests/test_delay_vars.py
  89. +2
    -203
      brainpy/measure/__init__.py
  90. +270
    -0
      brainpy/measure/correlation.py
  91. +76
    -0
      brainpy/measure/firings.py
  92. +59
    -0
      brainpy/measure/tests/test_correlation.py
  93. +22
    -0
      brainpy/measure/tests/test_firings.py
  94. +314
    -219
      brainpy/nn/base.py
  95. +3
    -4
      brainpy/nn/nodes/ANN/conv.py
  96. +3
    -3
      brainpy/nn/nodes/ANN/dropout.py
  97. +94
    -64
      brainpy/nn/nodes/ANN/rnn_cells.py
  98. +6
    -6
      brainpy/nn/nodes/RC/linear_readout.py
  99. +44
    -32
      brainpy/nn/nodes/RC/nvar.py
  100. +5
    -7
      brainpy/nn/nodes/RC/reservoir.py

+ 0
- 10
.github/ISSUE_TEMPLATE/Feature_request.md View File

@@ -1,10 +0,0 @@
---
name: 'Feature Request'
about: 'Suggest a new idea or improvement for Brainpy'
labels: 'enhancement'
---

Please:

- [ ] Check for duplicate requests.
- [ ] Describe your goal, and if possible provide a code snippet with a motivating example.

+ 1
- 1
.github/ISSUE_TEMPLATE/bug_report.md View File

@@ -1,5 +1,5 @@
---
name: 'Bug report'
name: 'Bug Report'
about: 'Report a bug to help improve the package'
labels: 'bug'
---


+ 1
- 1
.github/ISSUE_TEMPLATE/config.yml View File

@@ -1,5 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: Question
url: https://github.com/google/jax/discussions
url: https://github.com/PKU-NIP-Lab/BrainPy/discussions
about: Please ask questions on the Discussions tab

+ 3
- 3
.github/workflows/Linux_CI.yml View File

@@ -5,9 +5,9 @@ name: Linux CI

on:
push:
branches: [ master, brainpy-2.x, V2.1.0 ]
branches: [ master ]
pull_request:
branches: [ master, brainpy-2.x, V2.1.0 ]
branches: [ master ]


jobs:
@@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2


+ 5
- 6
.github/workflows/MacOS_CI.yml View File

@@ -5,19 +5,18 @@ name: MacOS CI

on:
push:
branches: [ master, brainpy-2.x, V2.1.0 ]
branches: [ master ]
pull_request:
branches: [ master, brainpy-2.x, V2.1.0 ]
branches: [ master ]


jobs:
build:
runs-on: ${{ matrix.os }}
runs-on: macos-latest
strategy:
fail-fast: false
matrix:
os: [macos-10.15, macos-11, macos-latest]
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2
@@ -39,4 +38,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
pytest brainpy/

+ 18
- 0
.github/workflows/Sync_branches.yml View File

@@ -0,0 +1,18 @@
name: Sync multiple branches
on:
pull_request:
branches:
- master
jobs:
sync-branch:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master

- name: Merge master -> brainpy-2.x
uses: devmasx/merge-branch@v1.3.1
with:
type: now
from_branch: master
target_branch: brainpy-2.x
github_token: ${{ github.token }}

+ 3
- 3
.github/workflows/Windows_CI.yml View File

@@ -5,9 +5,9 @@ name: Windows CI

on:
push:
branches: [ master, brainpy-2.x, V2.1.0 ]
branches: [ master ]
pull_request:
branches: [ master, brainpy-2.x, V2.1.0 ]
branches: [ master ]


jobs:
@@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2


+ 2
- 0
.github/workflows/contributors.yml View File

@@ -2,6 +2,8 @@ name: Add contributors
on:
schedule:
- cron: '20 20 * * *'
push:
branches: [ master ]

jobs:
add-contributors:


+ 1
- 0
.gitignore View File

@@ -17,6 +17,7 @@ BrainModels/
book/
docs/examples
docs/apis/jaxsetting.rst
docs/quickstart/data
examples/recurrent_neural_network/neurogym
develop/iconip_paper
develop/benchmark/COBA/results


+ 154
- 12
README.md View File

@@ -28,9 +28,9 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu



## Install
## Installation

BrainPy is based on Python (>=3.6) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. Install the latest version of BrainPy:
BrainPy is based on Python (>=3.7) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. Install the latest version of BrainPy:

```bash
$ pip install brain-py -U
@@ -54,7 +54,121 @@ import brainpy as bp



**1\. E-I balance network**
### 1. Operator level

Mathematical operators in BrainPy are the same as those in NumPy.

```python
>>> import numpy as np
>>> import brainpy.math as bm

# array creation
>>> np_arr = np.zeros((2, 4)); np_arr
array([[0., 0., 0., 0.],
[0., 0., 0., 0.]])
>>> bm_arr = bm.zeros((2, 4)); bm_arr
JaxArray([[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32)

# in-place updating
>>> np_arr[0] += 1.; np_arr
array([[1., 1., 1., 1.],
[0., 0., 0., 0.]])
>>> bm_arr[0] += 1.; bm_arr
JaxArray([[1., 1., 1., 1.],
[0., 0., 0., 0.]], dtype=float32)

# mathematical functions
>>> np.sin(np_arr)
array([[0.84147098, 0.84147098, 0.84147098, 0.84147098],
[0. , 0. , 0. , 0. ]])
>>> bm.sin(bm_arr)
JaxArray([[0.84147096, 0.84147096, 0.84147096, 0.84147096],
[0. , 0. , 0. , 0. ]], dtype=float32)

# linear algebra
>>> np.dot(np_arr, np.ones((4, 2)))
array([[4., 4.],
[0., 0.]])
>>> bm.dot(bm_arr, bm.ones((4, 2)))
JaxArray([[4., 4.],
[0., 0.]], dtype=float32)

# random number generation
>>> np.random.uniform(-0.1, 0.1, (2, 3))
array([[-0.02773637, 0.03766689, -0.01363128],
[-0.01946991, -0.06669802, 0.09426067]])
>>> bm.random.uniform(-0.1, 0.1, (2, 3))
JaxArray([[-0.03044081, -0.07787752, 0.04346445],
[-0.01366713, -0.0522548 , 0.04372055]], dtype=float32)
```



### 2. Integrator level

Numerical methods for ordinary differential equations (ODEs).

```python
sigma = 10; beta = 8/3; rho = 28

@bp.odeint(method='rk4')
def lorenz_system(x, y, z, t):
dx = sigma * (y - x)
dy = x * (rho - z) - y
dz = x * y - beta * z
return dx, dy, dz

runner = bp.integrators.IntegratorRunner(lorenz_system, dt=0.01)
runner.run(100.)
```



Numerical methods for stochastic differential equations (SDEs).

```python
sigma = 10; beta = 8/3; rho = 28
p=0.1

def lorenz_noise(x, y, z, t):
return p*x, p*y, p*z

@bp.odeint(method='milstein', g=lorenz_noise)
def lorenz_system(x, y, z, t):
dx = sigma * (y - x)
dy = x * (rho - z) - y
dz = x * y - beta * z
return dx, dy, dz

runner = bp.integrators.IntegratorRunner(lorenz_system, dt=0.01)
runner.run(100.)
```



Numerical methods for delay differential equations (SDEs).

```python
xdelay = bm.TimeDelay(bm.zeros(1), delay_len=1., before_t0=1., dt=0.01)


@bp.ddeint(method='rk4', state_delays={'x': xdelay})
def second_order_eq(x, y, t):
dx = y
dy = -y - 2 * x - 0.5 * xdelay(t - 1)
return dx, dy


runner = bp.integrators.IntegratorRunner(second_order_eq, dt=0.01)
runner.run(100.)
```



### 3. Dynamics simulation level

Building an E-I balance network.

```python
class EINet(bp.dyn.Network):
@@ -77,9 +191,36 @@ runner = bp.dyn.DSRunner(net)
runner(100.)
```

Simulating a whole brain network by using rate models.

```python
import numpy as np

class WholeBrainNet(bp.dyn.Network):
def __init__(self, signal_speed=20.):
super(WholeBrainNet, self).__init__()

**2\. Echo state network**
self.fhn = bp.dyn.RateFHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn')
self.syn = bp.dyn.DiffusiveDelayCoupling(self.fhn, self.fhn,
'x->input',
conn_mat=conn_mat,
delay_mat=delay_mat)

def update(self, _t, _dt):
self.syn.update(_t, _dt)
self.fhn.update(_t, _dt)


net = WholeBrainNet()
runner = bp.dyn.DSRunner(net, monitors=['fhn.x'], inputs=['fhn.input', 0.72])
runner.run(6e3)
```



### 4. Dynamics training level

Training an echo state network.

```python
i = bp.nn.Input(3)
@@ -88,16 +229,14 @@ o = bp.nn.LinearReadout(3)

net = i >> r >> o

# Ridge Regression
trainer = bp.nn.RidgeTrainer(net, beta=1e-5)
trainer = bp.nn.RidgeTrainer(net, beta=1e-5) # Ridge Regression

# FORCE Learning
trainer = bp.nn.FORCELearning(net, alpha=1.)
trainer = bp.nn.FORCELearning(net, alpha=1.) # FORCE Learning
```



**3. Next generation reservoir computing**
Training a next-generation reservoir computing model.

```python
i = bp.nn.Input(3)
@@ -111,7 +250,7 @@ trainer = bp.nn.RidgeTrainer(net, beta=1e-5)



**4. Recurrent neural network**
Training an artificial recurrent neural network.

```python
i = bp.nn.Input(3)
@@ -128,7 +267,9 @@ trainer = bp.nn.BPTT(net,



**5\. Analyzing a low-dimensional FitzHugh–Nagumo neuron model**
### 5. Dynamics analysis level

Analyzing a low-dimensional FitzHugh–Nagumo neuron model.

```python
bp.math.enable_x64()
@@ -149,9 +290,10 @@ analyzer.show_figure()
</p>


For **more functions and examples**, please refer to the [documentation](https://brainpy.readthedocs.io/) and [examples](https://brainpy-examples.readthedocs.io/).

### 6. More others

For **more functions and examples**, please refer to the [documentation](https://brainpy.readthedocs.io/) and [examples](https://brainpy-examples.readthedocs.io/).


## License


+ 0
- 159
README2.md View File

@@ -1,159 +0,0 @@
<p align="center">
<img alt="Header image of BrainPy - brain dynamics programming in Python." src="./images/logo.png" >
</p>


<p align="center">
<a href="https://pypi.org/project/brain-py/"><img alt="Supported Python Version" src="https://img.shields.io/pypi/pyversions/brain-py"></a>
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="LICENSE" src="https://anaconda.org/brainpy/brainpy/badges/license.svg"></a>
<a href="https://brainpy.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation" src="https://readthedocs.org/projects/brainpy/badge/?version=latest"></a>
<a href="https://badge.fury.io/py/brain-py"><img alt="PyPI version" src="https://badge.fury.io/py/brain-py.svg"></a>
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="Linux CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/Linux_CI.yml/badge.svg"></a>
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="Linux CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/Windows_CI.yml/badge.svg"></a>
</p>


:clap::clap: **CHEERS**: A new version of BrainPy (>=2.0.0, long term support) has been released! :clap::clap:



# Why use BrainPy

``BrainPy`` is an integrative framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax)). Core functions provided in BrainPy includes

- **JIT compilation** for class objects.
- **Numerical solvers** for ODEs, SDEs, and others.
- **Dynamics simulation tools** for various brain objects, like neurons, synapses, networks, soma, dendrites, channels, and even more.
- **Dynamics analysis tools** for differential equations, including phase plane analysis and bifurcation analysis, and linearization analysis.
- **Seamless integration with deep learning models**.
- And more ......

`BrainPy` is designed to effectively satisfy your basic requirements:

- **Pythonic**: BrainPy is based on Python language and has a Pythonic coding style.
- **Flexible and transparent**: BrainPy endows the users with full data/logic flow control. Users can code any logic they want with BrainPy.
- **Extensible**: BrainPy allows users to extend new functionality just based on Python code. Almost every part of the BrainPy system can be extended to be customized.
- **Efficient**: All codes in BrainPy can be just-in-time compiled (based on [JAX](https://github.com/google/jax)) to run on CPU, GPU, or TPU devices, thus guaranteeing its running efficiency.



# How to use BrainPy

## Step 1: installation

``BrainPy`` is based on Python (>=3.6), and the following packages are required to be installed to use ``BrainPy``: `numpy >= 1.15`, `matplotlib >= 3.4`, and `jax >= 0.2.10` ([how to install jax?](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax))

``BrainPy`` can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. Use the following instructions to install ``brainpy``:

```bash
pip install brain-py -U
```

*For the full installation details please see documentation: [Quickstart/Installation](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html)*




## Step 2: useful links

- **Documentation:** https://brainpy.readthedocs.io/
- **Bug reports:** https://github.com/PKU-NIP-Lab/BrainPy/issues
- **Examples from papers**: https://brainpy-examples.readthedocs.io/
- **Canonical brain models**: https://brainmodels.readthedocs.io/



## Step 3: inspirational examples

Here we list several examples of BrainPy. For more detailed examples and tutorials please see [**BrainModels**](https://brainmodels.readthedocs.io) or [**BrainPy-Examples**](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/).



### Neuron models

- [Leaky integrate-and-fire neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.LIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/LIF.py)
- [Exponential integrate-and-fire neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.ExpIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/ExpIF.py)
- [Quadratic integrate-and-fire neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.QuaIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/QuaIF.py)
- [Adaptive Quadratic integrate-and-fire model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.AdQuaIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/AdQuaIF.py)
- [Adaptive Exponential integrate-and-fire model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.AdExIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/AdExIF.py)
- [Generalized integrate-and-fire model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.GIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/GIF.py)
- [Hodgkin–Huxley neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.HH.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/HH.py)
- [Izhikevich neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.Izhikevich.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/Izhikevich.py)
- [Morris-Lecar neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.MorrisLecar.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/MorrisLecar.py)
- [Hindmarsh-Rose bursting neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.HindmarshRose.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/HindmarshRose.py)

See [brainmodels.neurons](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/neurons.html) to find more.



### Synapse models

- [Voltage jump synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.VoltageJump.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/voltage_jump.py)
- [Exponential synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.ExpCUBA.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/exponential.py)
- [Alpha synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.AlphaCUBA.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/alpha.py)
- [Dual exponential synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.DualExpCUBA.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/dual_exp.py)
- [AMPA synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.AMPA.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/AMPA.py)
- [GABAA synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.GABAa.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/GABAa.py)
- [NMDA synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.NMDA.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/NMDA.py)
- [Short-term plasticity model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.STP.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/STP.py)

See [brainmodels.synapses](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/synapses.html) to find more.



### Network models

- **[CANN]** [*(Si Wu, 2008)* Continuous-attractor Neural Network](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/cann/Wu_2008_CANN.html)
- [*(Vreeswijk & Sompolinsky, 1996)* E/I balanced network](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/ei_nets/Vreeswijk_1996_EI_net.html)
- [*(Sherman & Rinzel, 1992)* Gap junction leads to anti-synchronization](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/gj_nets/Sherman_1992_gj_antisynchrony.html)
- [*(Wang & Buzsáki, 1996)* Gamma Oscillation](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/oscillation_synchronization/Wang_1996_gamma_oscillation.html)
- [*(Brunel & Hakim, 1999)* Fast Global Oscillation](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/oscillation_synchronization/Brunel_Hakim_1999_fast_oscillation.html)
- [*(Diesmann, et, al., 1999)* Synfire Chains](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/oscillation_synchronization/Diesmann_1999_synfire_chains.html)
- **[Working Memory]** [*(Mi, et. al., 2017)* STP for Working Memory Capacity](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/working_memory/Mi_2017_working_memory_capacity.html)
- **[Working Memory]** [*(Bouchacourt & Buschman, 2019)* Flexible Working Memory Model](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/working_memory/Bouchacourt_2019_Flexible_working_memory.html)
- **[Decision Making]** [*(Wang, 2002)* Decision making spiking model](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/decision_making/Wang_2002_decision_making_spiking.html)



### Dynamics training

- [Train Integrator RNN with BP](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/integrator_rnn.html)

- [*(Sussillo & Abbott, 2009)* FORCE Learning](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/Sussillo_Abbott_2009_FORCE_Learning.html)

- [*(Laje & Buonomano, 2013)* Robust Timing in RNN](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/Laje_Buonomano_2013_robust_timing_rnn.html)
- [*(Song, et al., 2016)*: Training excitatory-inhibitory recurrent network](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/Song_2016_EI_RNN.html)
- **[Working Memory]** [*(Masse, et al., 2019)*: RNN with STP for Working Memory](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/Masse_2019_STP_RNN.html)




### Low-dimensional dynamics analysis

- [[1D] Simple systems](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/dynamics_analysis/1d_simple_systems.html)
- [[2D] NaK model analysis](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/dynamics_analysis/2d_NaK_model.html)
- [[3D] Hindmarsh Rose Model](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/dynamics_analysis/3d_hindmarsh_rose_model.html)
- **[Decision Making Model]** [[2D] Decision making rate model](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/decision_making/Wang_2006_decision_making_rate.html)



### High-dimensional dynamics analysis

- [*(Yang, 2020)*: Dynamical system analysis for RNN](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/Yang_2020_RNN_Analysis.html)
- [Continuous-attractor Neural Network](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/dynamics_analysis/highdim_CANN.html)
- [Gap junction-coupled FitzHugh-Nagumo Model](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/dynamics_analysis/highdim_gj_coupled_fhn.html)



# BrainPy 1.x

If you are using ``brainpy==1.x``, you can find *documentation*, *examples*, and *models* through the following links:

- **Documentation:** https://brainpy.readthedocs.io/en/brainpy-1.x/
- **Examples from papers**: https://brainpy-examples.readthedocs.io/en/brainpy-1.x/
- **Canonical brain models**: https://brainmodels.readthedocs.io/en/brainpy-1.x/

The changes from ``brainpy==1.x`` to ``brainpy==2.x`` can be inspected through [API documentation: release notes](https://brainpy.readthedocs.io/en/latest/apis/auto/changelog.html).


# Contributors

+ 6
- 4
brainpy/__init__.py View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.1.0"
__version__ = "2.1.2"


try:
@@ -15,7 +15,7 @@ except ModuleNotFoundError:


# fundamental modules
from . import errors, tools
from . import errors, tools, check


# "base" module
@@ -37,9 +37,11 @@ from . import integrators
from .integrators import ode
from .integrators import sde
from .integrators import dde
from .integrators import fde
from .integrators.ode import odeint
from .integrators.sde import sdeint
from .integrators.dde import ddeint
from .integrators.fde import fdeint
from .integrators.joint_eq import JointEq


@@ -59,12 +61,12 @@ from . import running
from . import analysis


# "visualization" module, will be remove soon
# "visualization" module, will be removed soon
from .visualization import visualize


# compatible interface
from .compact import * # compact
from .compat import * # compat


# convenient access


+ 39
- 21
brainpy/analysis/highdim/slow_points.py View File

@@ -1,17 +1,18 @@
# -*- coding: utf-8 -*-

import inspect
import time
import warnings
from functools import partial

from jax import vmap
import jax.numpy
import numpy as np
from jax.scipy.optimize import minimize

import brainpy.math as bm
from brainpy import optimizers as optim
from brainpy.analysis import utils
from brainpy.errors import AnalyzerError
from brainpy import optimizers as optim

__all__ = [
'SlowPointFinder',
@@ -56,15 +57,15 @@ class SlowPointFinder(object):
if f_loss_batch is None:
if f_type == 'discrete':
self.f_loss = bm.jit(lambda h: bm.mean((h - f_cell(h)) ** 2))
self.f_loss_batch = bm.jit(lambda h: bm.mean((h - bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1))
self.f_loss_batch = bm.jit(lambda h: bm.mean((h - vmap(f_cell)(h)) ** 2, axis=1))
if f_type == 'continuous':
self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
self.f_loss_batch = bm.jit(lambda h: bm.mean((bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1))
self.f_loss_batch = bm.jit(lambda h: bm.mean((vmap(f_cell)(h)) ** 2, axis=1))

else:
self.f_loss_batch = f_loss_batch
self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
self.f_jacob_batch = bm.jit(bm.vmap(bm.jacobian(f_cell)))
self.f_jacob_batch = bm.jit(vmap(bm.jacobian(f_cell)))

# essential variables
self._losses = None
@@ -87,8 +88,13 @@ class SlowPointFinder(object):
"""The selected ids of candidate points."""
return self._selected_ids

def find_fps_with_gd_method(self, candidates, tolerance=1e-5, num_batch=100,
num_opt=10000, opt_setting=None):
def find_fps_with_gd_method(self,
candidates,
tolerance=1e-5,
num_batch=100,
num_opt=10000,
optimizer=None,
opt_setting=None):
"""Optimize fixed points with gradient descent methods.

Parameters
@@ -104,17 +110,30 @@ class SlowPointFinder(object):
Print training information during optimization every so often.
opt_setting: optional, dict
The optimization settings.

.. deprecated:: 2.1.2
Use "optimizer" to set optimization method instead.

optimizer: optim.Optimizer
The optimizer instance.

.. versionadded:: 2.1.2
"""

# optimization settings
if opt_setting is None:
opt_method = optim.Adam
opt_lr = optim.ExponentialDecay(0.2, 1, 0.9999)
opt_setting = {'beta1': 0.9,
'beta2': 0.999,
'eps': 1e-8,
'name': None}
if optimizer is None:
optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
beta1=0.9, beta2=0.999, eps=1e-8)
else:
assert isinstance(optimizer, optim.Optimizer), (f'Must be an instance of '
f'{optim.Optimizer.__name__}, '
f'while we got {type(optimizer)}')
else:
warnings.warn('Please use "optimizer" to set optimization method. '
'"opt_setting" is deprecated since version 2.1.2. ',
DeprecationWarning)

assert isinstance(opt_setting, dict)
assert 'method' in opt_setting
assert 'lr' in opt_setting
@@ -122,26 +141,25 @@ class SlowPointFinder(object):
if isinstance(opt_method, str):
assert opt_method in optim.__dict__
opt_method = getattr(optim, opt_method)
assert isinstance(opt_method, type)
if optim.Optimizer not in inspect.getmro(opt_method):
raise ValueError
assert issubclass(opt_method, optim.Optimizer)
opt_lr = opt_setting.pop('lr')
assert isinstance(opt_lr, (int, float, optim.Scheduler))
opt_setting = opt_setting
optimizer = opt_method(lr=opt_lr, **opt_setting)

if self.verbose:
print(f"Optimizing with {opt_method.__name__} to find fixed points:")
print(f"Optimizing with {optimizer.__name__} to find fixed points:")

# set up optimization
fixed_points = bm.Variable(bm.asarray(candidates))
grad_f = bm.grad(lambda: self.f_loss_batch(fixed_points.value).mean(),
grad_vars={'a': fixed_points}, return_value=True)
opt = opt_method(train_vars={'a': fixed_points}, lr=opt_lr, **opt_setting)
dyn_vars = opt.vars() + {'_a': fixed_points}
optimizer.register_vars({'a': fixed_points})
dyn_vars = optimizer.vars() + {'_a': fixed_points}

def train(idx):
gradients, loss = grad_f()
opt.update(gradients)
optimizer.update(gradients)
return loss

@partial(bm.jit, dyn_vars=dyn_vars, static_argnames=('start_i', 'num_batch'))
@@ -191,7 +209,7 @@ class SlowPointFinder(object):
opt_method = lambda f, x0: minimize(f, x0, method='BFGS')
if self.verbose:
print(f"Optimizing to find fixed points:")
f_opt = bm.jit(bm.vmap(lambda x0: opt_method(self.f_loss, x0)))
f_opt = bm.jit(vmap(lambda x0: opt_method(self.f_loss, x0)))
res = f_opt(bm.as_device_array(candidates))
valid_ids = jax.numpy.where(res.success)[0]
self._fixed_points = np.asarray(res.x[valid_ids])


+ 26
- 21
brainpy/analysis/lowdim/lowdim_analyzer.py View File

@@ -2,8 +2,8 @@

from functools import partial

import matplotlib.pyplot as plt
import numpy as np
from jax import vmap
from jax import numpy as jnp
from jax.scipy.optimize import minimize

@@ -12,6 +12,8 @@ from brainpy import errors, tools
from brainpy.analysis import constants as C, utils
from brainpy.base.collector import Collector

pyplot = None

__all__ = [
'LowDimAnalyzer',
'Num1DAnalyzer',
@@ -207,7 +209,10 @@ class LowDimAnalyzer(object):
self.analyzed_results = tools.DictPlus()

def show_figure(self):
plt.show()
global pyplot
if pyplot is None:
from matplotlib import pyplot
pyplot.show()


class Num1DAnalyzer(LowDimAnalyzer):
@@ -258,7 +263,7 @@ class Num1DAnalyzer(LowDimAnalyzer):
@property
def F_vmap_fx(self):
if C.F_vmap_fx not in self.analyzed_results:
self.analyzed_results[C.F_vmap_fx] = bm.jit(bm.vmap(self.F_fx), device=self.jit_device)
self.analyzed_results[C.F_vmap_fx] = bm.jit(vmap(self.F_fx), device=self.jit_device)
return self.analyzed_results[C.F_vmap_fx]

@property
@@ -285,7 +290,7 @@ class Num1DAnalyzer(LowDimAnalyzer):
# ---
# "X": a two-dimensional matrix: (num_batch, num_var)
# "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(bm.vmap(self.F_fixed_point_aux))
self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(vmap(self.F_fixed_point_aux))
return self.analyzed_results[C.F_vmap_fp_aux]

@property
@@ -304,7 +309,7 @@ class Num1DAnalyzer(LowDimAnalyzer):
# ---
# "X": a two-dimensional matrix: (num_batch, num_var)
# "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(bm.vmap(self.F_fixed_point_opt))
self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(vmap(self.F_fixed_point_opt))
return self.analyzed_results[C.F_vmap_fp_opt]

def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_screen=None):
@@ -497,7 +502,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
@property
def F_vmap_fy(self):
if C.F_vmap_fy not in self.analyzed_results:
self.analyzed_results[C.F_vmap_fy] = bm.jit(bm.vmap(self.F_fy), device=self.jit_device)
self.analyzed_results[C.F_vmap_fy] = bm.jit(vmap(self.F_fy), device=self.jit_device)
return self.analyzed_results[C.F_vmap_fy]

@property
@@ -659,7 +664,7 @@ class Num2DAnalyzer(Num1DAnalyzer):

if self.F_x_by_y_in_fx is not None:
utils.output("I am evaluating fx-nullcline by F_x_by_y_in_fx ...")
vmap_f = bm.jit(bm.vmap(self.F_x_by_y_in_fx), device=self.jit_device)
vmap_f = bm.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((ys,) + pars))
@@ -675,7 +680,7 @@ class Num2DAnalyzer(Num1DAnalyzer):

elif self.F_y_by_x_in_fx is not None:
utils.output("I am evaluating fx-nullcline by F_y_by_x_in_fx ...")
vmap_f = bm.jit(bm.vmap(self.F_y_by_x_in_fx), device=self.jit_device)
vmap_f = bm.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((xs,) + pars))
@@ -693,9 +698,9 @@ class Num2DAnalyzer(Num1DAnalyzer):
utils.output("I am evaluating fx-nullcline by optimization ...")
# auxiliary functions
f2 = lambda y, x, *pars: self.F_fx(x, y, *pars)
vmap_f2 = bm.jit(bm.vmap(f2), device=self.jit_device)
vmap_brentq_f2 = bm.jit(bm.vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = bm.jit(bm.vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)
vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)

# num segments
for _j, Ps in enumerate(par_seg):
@@ -752,7 +757,7 @@ class Num2DAnalyzer(Num1DAnalyzer):

if self.F_x_by_y_in_fy is not None:
utils.output("I am evaluating fy-nullcline by F_x_by_y_in_fy ...")
vmap_f = bm.jit(bm.vmap(self.F_x_by_y_in_fy), device=self.jit_device)
vmap_f = bm.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((ys,) + pars))
@@ -768,7 +773,7 @@ class Num2DAnalyzer(Num1DAnalyzer):

elif self.F_y_by_x_in_fy is not None:
utils.output("I am evaluating fy-nullcline by F_y_by_x_in_fy ...")
vmap_f = bm.jit(bm.vmap(self.F_y_by_x_in_fy), device=self.jit_device)
vmap_f = bm.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((xs,) + pars))
@@ -787,9 +792,9 @@ class Num2DAnalyzer(Num1DAnalyzer):

# auxiliary functions
f2 = lambda y, x, *pars: self.F_fy(x, y, *pars)
vmap_f2 = bm.jit(bm.vmap(f2), device=self.jit_device)
vmap_brentq_f2 = bm.jit(bm.vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = bm.jit(bm.vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)

for j, Ps in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
@@ -837,7 +842,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
xs = self.resolutions[self.x_var].value
ys = self.resolutions[self.y_var].value
P = tuple(self.resolutions[p].value for p in self.target_par_names)
f_select = bm.jit(bm.vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))
f_select = bm.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))

# num seguments
if isinstance(num_segments, int):
@@ -917,10 +922,10 @@ class Num2DAnalyzer(Num1DAnalyzer):

if self.convert_type() == C.x_by_y:
num_seg = len(self.resolutions[self.y_var])
f_vmap = bm.jit(bm.vmap(self.F_y_convert[1]))
f_vmap = bm.jit(vmap(self.F_y_convert[1]))
else:
num_seg = len(self.resolutions[self.x_var])
f_vmap = bm.jit(bm.vmap(self.F_x_convert[1]))
f_vmap = bm.jit(vmap(self.F_x_convert[1]))
# get the signs
signs = jnp.sign(f_vmap(candidates, *args))
signs = signs.reshape((num_seg, -1))
@@ -950,10 +955,10 @@ class Num2DAnalyzer(Num1DAnalyzer):
# get another value
if self.convert_type() == C.x_by_y:
y_values = fps
x_values = bm.jit(bm.vmap(self.F_y_convert[0]))(y_values, *args)
x_values = bm.jit(vmap(self.F_y_convert[0]))(y_values, *args)
else:
x_values = fps
y_values = bm.jit(bm.vmap(self.F_x_convert[0]))(x_values, *args)
y_values = bm.jit(vmap(self.F_x_convert[0]))(x_values, *args)
fps = jnp.stack([x_values, y_values]).T
return fps, selected_ids, args



+ 74
- 60
brainpy/analysis/lowdim/lowdim_bifurcation.py View File

@@ -3,7 +3,7 @@
from functools import partial

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import vmap
import numpy as np

import brainpy.math as bm
@@ -11,6 +11,8 @@ from brainpy import errors
from brainpy.analysis import stability, utils, constants as C
from brainpy.analysis.lowdim.lowdim_analyzer import *

pyplot = None

__all__ = [
'Bifurcation1D',
'Bifurcation2D',
@@ -41,12 +43,14 @@ class Bifurcation1D(Num1DAnalyzer):
@property
def F_vmap_dfxdx(self):
if C.F_vmap_dfxdx not in self.analyzed_results:
f = bm.jit(bm.vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
f = bm.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
self.analyzed_results[C.F_vmap_dfxdx] = f
return self.analyzed_results[C.F_vmap_dfxdx]

def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
tol_aux=1e-8, loss_screen=None):
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am making bifurcation analysis ...')

xs = self.resolutions[self.x_var]
@@ -72,21 +76,21 @@ class Bifurcation1D(Num1DAnalyzer):
container[fp_type]['x'].append(x)

# visualization
plt.figure(self.x_var)
pyplot.figure(self.x_var)
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
plt.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
plt.xlabel(self.target_par_names[0])
plt.ylabel(self.x_var)
pyplot.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
pyplot.xlabel(self.target_par_names[0])
pyplot.ylabel(self.x_var)

scale = (self.lim_scale - 1) / 2
plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
plt.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
pyplot.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
pyplot.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale))

plt.legend()
pyplot.legend()
if show:
plt.show()
pyplot.show()

elif len(self.target_pars) == 2:
container = {c: {'p0': [], 'p1': [], 'x': []} for c in stability.get_1d_stability_types()}
@@ -99,7 +103,7 @@ class Bifurcation1D(Num1DAnalyzer):
container[fp_type]['x'].append(x)

# visualization
fig = plt.figure(self.x_var)
fig = pyplot.figure(self.x_var)
ax = fig.add_subplot(projection='3d')
for fp_type, points in container.items():
if len(points['x']):
@@ -121,7 +125,7 @@ class Bifurcation1D(Num1DAnalyzer):
ax.grid(True)
ax.legend()
if show:
plt.show()
pyplot.show()

else:
raise errors.BrainPyError(f'Cannot visualize co-dimension {len(self.target_pars)} '
@@ -156,7 +160,7 @@ class Bifurcation2D(Num2DAnalyzer):
if C.F_vmap_jacobian not in self.analyzed_results:
f1 = lambda xy, *args: jnp.array([self.F_fx(xy[0], xy[1], *args),
self.F_fy(xy[0], xy[1], *args)])
f2 = bm.jit(bm.vmap(bm.jacobian(f1)), device=self.jit_device)
f2 = bm.jit(vmap(bm.jacobian(f1)), device=self.jit_device)
self.analyzed_results[C.F_vmap_jacobian] = f2
return self.analyzed_results[C.F_vmap_jacobian]

@@ -212,6 +216,8 @@ class Bifurcation2D(Num2DAnalyzer):
- parameters: a 2D matrix with the shape of (num_point, num_par)
- jacobians: a 3D tensors with the shape of (num_point, 2, 2)
"""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am making bifurcation analysis ...')

if self._can_convert_to_one_eq():
@@ -289,21 +295,21 @@ class Bifurcation2D(Num2DAnalyzer):

# visualization
for var in self.target_var_names:
plt.figure(var)
pyplot.figure(var)
for fp_type, points in container.items():
if len(points['p']):
plot_style = stability.plot_scheme[fp_type]
plt.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
plt.xlabel(self.target_par_names[0])
plt.ylabel(var)
pyplot.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
pyplot.xlabel(self.target_par_names[0])
pyplot.ylabel(var)

scale = (self.lim_scale - 1) / 2
plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
plt.ylim(*utils.rescale(self.target_vars[var], scale=scale))
pyplot.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
pyplot.ylim(*utils.rescale(self.target_vars[var], scale=scale))

plt.legend()
pyplot.legend()
if show:
plt.show()
pyplot.show()

# bifurcation analysis of co-dimension 2
elif len(self.target_pars) == 2:
@@ -320,7 +326,7 @@ class Bifurcation2D(Num2DAnalyzer):

# visualization
for var in self.target_var_names:
fig = plt.figure(var)
fig = pyplot.figure(var)
ax = fig.add_subplot(projection='3d')
for fp_type, points in container.items():
if len(points['p0']):
@@ -340,7 +346,7 @@ class Bifurcation2D(Num2DAnalyzer):
ax.grid(True)
ax.legend()
if show:
plt.show()
pyplot.show()

else:
raise ValueError('Unknown length of parameters.')
@@ -350,6 +356,8 @@ class Bifurcation2D(Num2DAnalyzer):

def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am plotting the limit cycle ...')
if self._fixed_points is None:
utils.output('No fixed points found, you may call "plot_bifurcation(with_plot=True)" first.')
@@ -386,31 +394,33 @@ class Bifurcation2D(Num2DAnalyzer):
# visualization
if with_plot:
if plot_style is None: plot_style = dict()
fmt = plot_style.pop('fmt', '.')
fmt = plot_style.pop('fmt', '*')

if len(self.target_par_names) == 2:
for i, var in enumerate(self.target_var_names):
plt.figure(var)
plt.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
**plot_style, label='limit cycle (max)')
plt.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
**plot_style, label='limit cycle (min)')
plt.legend()
if len(ps_limit_cycle[0]):
for i, var in enumerate(self.target_var_names):
pyplot.figure(var)
pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
**plot_style, label='limit cycle (max)')
pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
**plot_style, label='limit cycle (min)')
pyplot.legend()

elif len(self.target_par_names) == 1:
for i, var in enumerate(self.target_var_names):
plt.figure(var)
plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt,
**plot_style, label='limit cycle (max)')
plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt,
**plot_style, label='limit cycle (min)')
plt.legend()
if len(ps_limit_cycle[0]):
for i, var in enumerate(self.target_var_names):
pyplot.figure(var)
pyplot.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt,
**plot_style, label='limit cycle (max)')
pyplot.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt,
**plot_style, label='limit cycle (min)')
pyplot.legend()

else:
raise errors.AnalyzerError

if show:
plt.show()
pyplot.show()

if with_return:
return vs_limit_cycle, ps_limit_cycle
@@ -437,6 +447,8 @@ class FastSlow1D(Bifurcation1D):

def plot_trajectory(self, initials, duration, plot_durations=None,
dt=None, show=False, with_plot=True, with_return=False):
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am plotting the trajectory ...')

# check the initial values
@@ -470,14 +482,14 @@ class FastSlow1D(Bifurcation1D):
end = int(plot_durations[i][1] / dt)
p1_var = self.target_par_names[0]
if len(self.target_par_names) == 1:
lines = plt.plot(mon_res[self.x_var][start: end, i],
mon_res[p1_var][start: end, i], label=legend)
lines = pyplot.plot(mon_res[self.x_var][start: end, i],
mon_res[p1_var][start: end, i], label=legend)
elif len(self.target_par_names) == 2:
p2_var = self.target_par_names[1]
lines = plt.plot(mon_res[self.x_var][start: end, i],
mon_res[p1_var][start: end, i],
mon_res[p2_var][start: end, i],
label=legend)
lines = pyplot.plot(mon_res[self.x_var][start: end, i],
mon_res[p1_var][start: end, i],
mon_res[p2_var][start: end, i],
label=legend)
else:
raise ValueError
utils.add_arrow(lines[0])
@@ -488,10 +500,10 @@ class FastSlow1D(Bifurcation1D):
# scale = (self.lim_scale - 1.) / 2
# plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
# plt.ylim(*utils.rescale(self.target_vars[self.target_par_names[0]], scale=scale))
plt.legend()
pyplot.legend()

if show:
plt.show()
pyplot.show()

if with_return:
return mon_res
@@ -517,6 +529,8 @@ class FastSlow2D(Bifurcation2D):

def plot_trajectory(self, initials, duration, plot_durations=None,
dt=None, show=False, with_plot=True, with_return=False):
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am plotting the trajectory ...')

# check the initial values
@@ -548,25 +562,25 @@ class FastSlow2D(Bifurcation2D):
end = int(plot_durations[i][1] / dt)

# visualization
plt.figure(self.x_var)
lines = plt.plot(mon_res[self.target_par_names[0]][start: end, i],
mon_res[self.x_var][start: end, i],
label=legend)
pyplot.figure(self.x_var)
lines = pyplot.plot(mon_res[self.target_par_names[0]][start: end, i],
mon_res[self.x_var][start: end, i],
label=legend)
utils.add_arrow(lines[0])

plt.figure(self.y_var)
lines = plt.plot(mon_res[self.target_par_names[0]][start: end, i],
mon_res[self.y_var][start: end, i],
label=legend)
pyplot.figure(self.y_var)
lines = pyplot.plot(mon_res[self.target_par_names[0]][start: end, i],
mon_res[self.y_var][start: end, i],
label=legend)
utils.add_arrow(lines[0])

plt.figure(self.x_var)
plt.legend()
plt.figure(self.y_var)
plt.legend()
pyplot.figure(self.x_var)
pyplot.legend()
pyplot.figure(self.y_var)
pyplot.legend()

if show:
plt.show()
pyplot.show()

if with_return:
return mon_res

+ 67
- 49
brainpy/analysis/lowdim/lowdim_phase_plane.py View File

@@ -1,14 +1,16 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap

import brainpy.math as bm
from brainpy import errors, math
from brainpy.analysis import stability, constants as C, utils
from brainpy.analysis.lowdim.lowdim_analyzer import *

pyplot = None

__all__ = [
'PhasePlane1D',
'PhasePlane2D',
@@ -62,6 +64,8 @@ class PhasePlane1D(Num1DAnalyzer):

def plot_vector_field(self, show=False, with_plot=True, with_return=False):
"""Plot the vector filed."""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am creating the vector field ...')

# Nullcline of the x variable
@@ -72,19 +76,21 @@ class PhasePlane1D(Num1DAnalyzer):
if with_plot:
label = f"d{self.x_var}dt"
x_style = dict(color='lightcoral', alpha=.7, linewidth=4)
plt.plot(np.asarray(self.resolutions[self.x_var]), y_val, **x_style, label=label)
plt.axhline(0)
plt.xlabel(self.x_var)
plt.ylabel(label)
plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=(self.lim_scale - 1.) / 2))
plt.legend()
if show: plt.show()
pyplot.plot(np.asarray(self.resolutions[self.x_var]), y_val, **x_style, label=label)
pyplot.axhline(0)
pyplot.xlabel(self.x_var)
pyplot.ylabel(label)
pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=(self.lim_scale - 1.) / 2))
pyplot.legend()
if show: pyplot.show()
# return
if with_return:
return y_val

def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
"""Plot the fixed point."""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am searching fixed points ...')

# fixed points and stability analysis
@@ -102,10 +108,10 @@ class PhasePlane1D(Num1DAnalyzer):
for fp_type, points in container.items():
if len(points):
plot_style = stability.plot_scheme[fp_type]
plt.plot(points, [0] * len(points), '.', markersize=20, **plot_style, label=fp_type)
plt.legend()
pyplot.plot(points, [0] * len(points), '.', markersize=20, **plot_style, label=fp_type)
pyplot.legend()
if show:
plt.show()
pyplot.show()

# return
if with_return:
@@ -153,7 +159,7 @@ class PhasePlane2D(Num2DAnalyzer):
@property
def F_vmap_brentq_fy(self):
if C.F_vmap_brentq_fy not in self.analyzed_results:
f_opt = bm.jit(bm.vmap(utils.jax_brentq(self.F_fy)))
f_opt = bm.jit(vmap(utils.jax_brentq(self.F_fy)))
self.analyzed_results[C.F_vmap_brentq_fy] = f_opt
return self.analyzed_results[C.F_vmap_brentq_fy]

@@ -178,6 +184,8 @@ class PhasePlane2D(Num2DAnalyzer):
"units", "angles", "scale". More settings please check
https://matplotlib.org/api/_as_gen/matplotlib.pyplot.quiver.html.
"""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am creating the vector field ...')

# get vector fields
@@ -197,7 +205,7 @@ class PhasePlane2D(Num2DAnalyzer):
speed = np.sqrt(dx ** 2 + dy ** 2)
dx = dx / speed
dy = dy / speed
plt.quiver(X, Y, dx, dy, **plot_style)
pyplot.quiver(X, Y, dx, dy, **plot_style)
elif plot_method == 'streamplot':
if plot_style is None:
plot_style = dict(arrowsize=1.2, density=1, color='thistle')
@@ -207,15 +215,15 @@ class PhasePlane2D(Num2DAnalyzer):
min_width, max_width = 0.5, 5.5
speed = np.nan_to_num(np.sqrt(dx ** 2 + dy ** 2))
linewidth = min_width + max_width * (speed / speed.max())
plt.streamplot(X, Y, dx, dy, linewidth=linewidth, **plot_style)
pyplot.streamplot(X, Y, dx, dy, linewidth=linewidth, **plot_style)
else:
raise errors.AnalyzerError(f'Unknown plot_method "{plot_method}", '
f'only supports "quiver" and "streamplot".')

plt.xlabel(self.x_var)
plt.ylabel(self.y_var)
pyplot.xlabel(self.x_var)
pyplot.ylabel(self.y_var)
if show:
plt.show()
pyplot.show()

if with_return: # return vector fields
return dx, dy
@@ -224,6 +232,8 @@ class PhasePlane2D(Num2DAnalyzer):
y_style=None, x_style=None, show=False,
coords=None, tol_nullcline=1e-7):
"""Plot the nullcline."""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am computing fx-nullcline ...')

if coords is None:
@@ -240,7 +250,7 @@ class PhasePlane2D(Num2DAnalyzer):
if x_style is None:
x_style = dict(color='cornflowerblue', alpha=.7, )
fmt = x_style.pop('fmt', '.')
plt.plot(x_values_in_fx, y_values_in_fx, fmt, **x_style, label=f"{self.x_var} nullcline")
pyplot.plot(x_values_in_fx, y_values_in_fx, fmt, **x_style, label=f"{self.x_var} nullcline")

# Nullcline of the y variable
utils.output('I am computing fy-nullcline ...')
@@ -252,17 +262,17 @@ class PhasePlane2D(Num2DAnalyzer):
if y_style is None:
y_style = dict(color='lightcoral', alpha=.7, )
fmt = y_style.pop('fmt', '.')
plt.plot(x_values_in_fy, y_values_in_fy, fmt, **y_style, label=f"{self.y_var} nullcline")
pyplot.plot(x_values_in_fy, y_values_in_fy, fmt, **y_style, label=f"{self.y_var} nullcline")

if with_plot:
plt.xlabel(self.x_var)
plt.ylabel(self.y_var)
pyplot.xlabel(self.x_var)
pyplot.ylabel(self.y_var)
scale = (self.lim_scale - 1.) / 2
plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
plt.legend()
pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
pyplot.legend()
if show:
plt.show()
pyplot.show()

if with_return:
return {self.x_var: (x_values_in_fx, y_values_in_fx),
@@ -273,6 +283,8 @@ class PhasePlane2D(Num2DAnalyzer):
select_candidates='fx-nullcline', num_rank=100, ):
"""Plot the fixed point and analyze its stability.
"""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am searching fixed points ...')

if self._can_convert_to_one_eq():
@@ -338,10 +350,10 @@ class PhasePlane2D(Num2DAnalyzer):
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
plt.plot(points['x'], points['y'], '.', markersize=20, **plot_style, label=fp_type)
plt.legend()
pyplot.plot(points['x'], points['y'], '.', markersize=20, **plot_style, label=fp_type)
pyplot.legend()
if show:
plt.show()
pyplot.show()

if with_return:
return fixed_points
@@ -377,7 +389,8 @@ class PhasePlane2D(Num2DAnalyzer):
show : bool
Whether show or not.
"""

global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am plotting the trajectory ...')

if axes not in ['v-v', 't-v']:
@@ -413,28 +426,31 @@ class PhasePlane2D(Num2DAnalyzer):
start = int(plot_durations[i][0] / dt)
end = int(plot_durations[i][1] / dt)
if axes == 'v-v':
lines = plt.plot(mon_res[self.x_var][start: end, i], mon_res[self.y_var][start: end, i],
label=legend, **kwargs)
lines = pyplot.plot(mon_res[self.x_var][start: end, i],
mon_res[self.y_var][start: end, i],
label=legend, **kwargs)
utils.add_arrow(lines[0])
else:
plt.plot(mon_res.ts[start: end], mon_res[self.x_var][start: end, i],
label=legend + f', {self.x_var}', **kwargs)
plt.plot(mon_res.ts[start: end], mon_res[self.y_var][start: end, i],
label=legend + f', {self.y_var}', **kwargs)
pyplot.plot(mon_res.ts[start: end],
mon_res[self.x_var][start: end, i],
label=legend + f', {self.x_var}', **kwargs)
pyplot.plot(mon_res.ts[start: end],
mon_res[self.y_var][start: end, i],
label=legend + f', {self.y_var}', **kwargs)

# visualization of others
if axes == 'v-v':
plt.xlabel(self.x_var)
plt.ylabel(self.y_var)
pyplot.xlabel(self.x_var)
pyplot.ylabel(self.y_var)
scale = (self.lim_scale - 1.) / 2
plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
plt.legend()
pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
pyplot.legend()
else:
plt.legend(title='Initial values')
pyplot.legend(title='Initial values')

if show:
plt.show()
pyplot.show()

if with_return:
return mon_res
@@ -462,6 +478,8 @@ class PhasePlane2D(Num2DAnalyzer):
show : bool
Whether show or not.
"""
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am plotting the limit cycle ...')

# 1. format the initial values
@@ -487,18 +505,18 @@ class PhasePlane2D(Num2DAnalyzer):
x_cycle = x_data[max_index[0]: max_index[1]]
y_cycle = y_data[max_index[0]: max_index[1]]
# 5.5 visualization
lines = plt.plot(x_cycle, y_cycle, label='limit cycle')
lines = pyplot.plot(x_cycle, y_cycle, label='limit cycle')
utils.add_arrow(lines[0])
else:
utils.output(f'No limit cycle found for initial value {initial}')

# 6. visualization
plt.xlabel(self.x_var)
plt.ylabel(self.y_var)
pyplot.xlabel(self.x_var)
pyplot.ylabel(self.y_var)
scale = (self.lim_scale - 1.) / 2
plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
plt.legend()
pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
pyplot.legend()

if show:
plt.show()
pyplot.show()

+ 4
- 2
brainpy/analysis/utils/measurement.py View File

@@ -2,6 +2,7 @@

import jax.numpy as jnp
import numpy as np
from brainpy.tools.others import numba_jit


__all__ = [
@@ -10,7 +11,7 @@ __all__ = [
]


# @tools.numba_jit
@numba_jit
def _f1(arr, grad, tol):
condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] >= 0)
indexes = np.where(condition)[0]
@@ -19,7 +20,8 @@ def _f1(arr, grad, tol):
length = np.max(data) - np.min(data)
a = arr[indexes[-2]]
b = arr[indexes[-1]]
if np.abs(a - b) <= tol * length:
# TODO: how to choose length threshold, 1e-3?
if length > 1e-3 and np.abs(a - b) <= tol * length:
return indexes[-2:]
return np.array([-1, -1])



+ 1
- 2
brainpy/analysis/utils/model.py View File

@@ -49,8 +49,7 @@ def model_transform(model):
new_model = []
for intg in model:
if isinstance(intg.f, JointEq):
new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt, dyn_var=intg.dyn_var)
for eq in intg.f.eqs])
new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt) for eq in intg.f.eqs])
else:
new_model.append(intg)



+ 4
- 4
brainpy/analysis/utils/optimization.py View File

@@ -4,7 +4,7 @@
import jax.lax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit
from jax import grad, jit, vmap
from jax.flatten_util import ravel_pytree

import brainpy.math as bm
@@ -197,7 +197,7 @@ def brentq_candidates(vmap_f, *values, args=()):

def brentq_roots(f, starts, ends, *vmap_args, args=()):
in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args)))
vmap_f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=in_axes))
vmap_f_opt = bm.jit(vmap(jax_brentq(f), in_axes=in_axes))
all_args = vmap_args + args
if len(all_args):
res = vmap_f_opt(starts, ends, all_args)
@@ -397,7 +397,7 @@ def roots_of_1d_by_x(f, candidates, args=()):
return fps
starts = candidates[candidate_ids]
ends = candidates[candidate_ids + 1]
f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=(0, 0, None)))
f_opt = bm.jit(vmap(jax_brentq(f), in_axes=(0, 0, None)))
res = f_opt(starts, ends, args)
valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
fps2 = res['root'][valid_idx]
@@ -406,7 +406,7 @@ def roots_of_1d_by_x(f, candidates, args=()):

def roots_of_1d_by_xy(f, starts, ends, args):
f = f_without_jaxarray_return(f)
f_opt = bm.jit(bm.vmap(jax_brentq(f)))
f_opt = bm.jit(vmap(jax_brentq(f)))
res = f_opt(starts, ends, (args,))
valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
xs = res['root'][valid_idx]


+ 2
- 1
brainpy/analysis/utils/others.py View File

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp
from jax import vmap
import numpy as np

import brainpy.math as bm
@@ -76,7 +77,7 @@ def get_sign(f, xs, ys):

def get_sign2(f, *xyz, args=()):
in_axes = tuple(range(len(xyz))) + tuple([None] * len(args))
f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
f = bm.jit(vmap(f_without_jaxarray_return(f), in_axes=in_axes))
xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz)
XYZ = jnp.meshgrid(*xyz)
XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ)


+ 27
- 0
brainpy/check.py View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-


__all__ = [
'is_checking',
'turn_on',
'turn_off',
]

_check = True


def is_checking():
"""Whether the checking is turn on."""
return _check


def turn_on():
"""Turn on the checking."""
global _check
_check = True


def turn_off():
"""Turn off the checking."""
global _check
_check = False

+ 0
- 12
brainpy/compact/models.py View File

@@ -1,12 +0,0 @@
# -*- coding: utf-8 -*-

from brainpy.dyn import LIF, AdExIF, Izhikevich, ExpCOBA, ExpCUBA, DeltaSynapse

__all__ = [
'LIF',
'AdExIF',
'Izhikevich',
'ExpCOBA',
'ExpCUBA',
'DeltaSynapse',
]

+ 0
- 11
brainpy/compact/runners.py View File

@@ -1,11 +0,0 @@
# -*- coding: utf-8 -*-

from brainpy.integrators.runner import IntegratorRunner
from brainpy.dyn.runners import DSRunner, StructRunner, ReportRunner

__all__ = [
'IntegratorRunner',
'DSRunner',
'StructRunner',
'ReportRunner'
]

brainpy/compact/__init__.py → brainpy/compat/__init__.py View File


brainpy/compact/brainobjects.py → brainpy/compat/brainobjects.py View File

@@ -15,6 +15,11 @@ __all__ = [


class DynamicalSystem(dyn.DynamicalSystem):
"""Dynamical System.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.DynamicalSystem" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.DynamicalSystem" instead. '
'"brainpy.DynamicalSystem" is deprecated since '
@@ -23,6 +28,11 @@ class DynamicalSystem(dyn.DynamicalSystem):


class Container(dyn.Container):
"""Container.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.Container" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.Container" instead. '
'"brainpy.Container" is deprecated since '
@@ -31,6 +41,11 @@ class Container(dyn.Container):


class Network(dyn.Network):
"""Network.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.Network" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.Network" instead. '
'"brainpy.Network" is deprecated since '
@@ -39,6 +54,11 @@ class Network(dyn.Network):


class ConstantDelay(dyn.ConstantDelay):
"""Constant Delay.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.ConstantDelay" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.ConstantDelay" instead. '
'"brainpy.ConstantDelay" is deprecated since '
@@ -47,6 +67,11 @@ class ConstantDelay(dyn.ConstantDelay):


class NeuGroup(dyn.NeuGroup):
"""Neuron group.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.NeuGroup" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.NeuGroup" instead. '
'"brainpy.NeuGroup" is deprecated since '
@@ -55,6 +80,11 @@ class NeuGroup(dyn.NeuGroup):


class TwoEndConn(dyn.TwoEndConn):
"""Two-end synaptic connection.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.TwoEndConn" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.TwoEndConn" instead. '
'"brainpy.TwoEndConn" is deprecated since '

brainpy/compact/integrators.py → brainpy/compat/integrators.py View File

@@ -13,6 +13,11 @@ __all__ = [


def set_default_odeint(method):
"""Set default ode integrator.

.. deprecated:: 2.1.0
Please use "brainpy.ode.set_default_odeint" instead.
"""
warnings.warn('Please use "brainpy.ode.set_default_odeint" instead. '
'"brainpy.set_default_odeint" is deprecated since '
'version 2.1.0', DeprecationWarning)
@@ -20,6 +25,11 @@ def set_default_odeint(method):


def get_default_odeint():
"""Get default ode integrator.

.. deprecated:: 2.1.0
Please use "brainpy.ode.get_default_odeint" instead.
"""
warnings.warn('Please use "brainpy.ode.get_default_odeint" instead. '
'"brainpy.get_default_odeint" is deprecated since '
'version 2.1.0', DeprecationWarning)
@@ -27,6 +37,11 @@ def get_default_odeint():


def set_default_sdeint(method):
"""Set default sde integrator.

.. deprecated:: 2.1.0
Please use "brainpy.ode.set_default_sdeint" instead.
"""
warnings.warn('Please use "brainpy.sde.set_default_sdeint" instead. '
'"brainpy.set_default_sdeint" is deprecated since '
'version 2.1.0', DeprecationWarning)
@@ -34,6 +49,11 @@ def set_default_sdeint(method):


def get_default_sdeint():
"""Get default sde integrator.

.. deprecated:: 2.1.0
Please use "brainpy.ode.get_default_sdeint" instead.
"""
warnings.warn('Please use "brainpy.sde.get_default_sdeint" instead. '
'"brainpy.get_default_sdeint" is deprecated since '
'version 2.1.0', DeprecationWarning)

brainpy/compact/layers.py → brainpy/compat/layers.py View File

@@ -23,7 +23,10 @@ def _check_args(args):


class Module(Base):
"""Basic module class."""
"""Basic module class.

.. deprecated:: 2.1.0
"""

@staticmethod
def get_param(param, size):
@@ -47,7 +50,7 @@ class Module(Base):
def __init__(self, name=None): # initialize parameters
warnings.warn('Please use "brainpy.rnns.Module" instead. '
'"brainpy.layers.Module" is deprecated since '
'version 2.0.3.', DeprecationWarning)
'version 2.1.0.', DeprecationWarning)
super(Module, self).__init__(name=name)

def __call__(self, *args, **kwargs): # initialize variables

+ 98
- 0
brainpy/compat/models.py View File

@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-

import warnings

from brainpy.dyn import neurons, synapses

__all__ = [
'LIF',
'AdExIF',
'Izhikevich',
'ExpCOBA',
'ExpCUBA',
'DeltaSynapse',
]


class LIF(neurons.LIF):
"""LIF neuron model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.LIF" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.LIF" instead. '
'"brainpy.models.LIF" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(LIF, self).__init__(*args, **kwargs)


class AdExIF(neurons.AdExIF):
"""AdExIF neuron model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.AdExIF" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.AdExIF" instead. '
'"brainpy.models.AdExIF" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(AdExIF, self).__init__(*args, **kwargs)


class Izhikevich(neurons.Izhikevich):
"""Izhikevich neuron model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.Izhikevich" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.Izhikevich" instead. '
'"brainpy.models.Izhikevich" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(Izhikevich, self).__init__(*args, **kwargs)


class ExpCOBA(synapses.ExpCOBA):
"""ExpCOBA synapse model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.ExpCOBA" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.ExpCOBA" instead. '
'"brainpy.models.ExpCOBA" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(ExpCOBA, self).__init__(*args, **kwargs)


class ExpCUBA(synapses.ExpCUBA):
"""ExpCUBA synapse model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.ExpCUBA" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.ExpCUBA" instead. '
'"brainpy.models.ExpCUBA" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(ExpCUBA, self).__init__(*args, **kwargs)


class DeltaSynapse(synapses.DeltaSynapse):
"""Delta synapse model.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.DeltaSynapse" instead.
"""

def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.DeltaSynapse" instead. '
'"brainpy.models.DeltaSynapse" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(DeltaSynapse, self).__init__(*args, **kwargs)

brainpy/compact/monitor.py → brainpy/compat/monitor.py View File

@@ -9,8 +9,13 @@ __all__ = [


class Monitor(monitor.Monitor):
"""Monitor class.

.. deprecated:: 2.1.0
Please use "brainpy.running.Monitor" instead.
"""
def __init__(self, *args, **kwargs):
super(Monitor, self).__init__(*args, **kwargs)
warnings.warn('Please use "brainpy.running.Monitor" instead. '
'"brainpy.Monitor" is deprecated since version 2.0.3.',
'"brainpy.Monitor" is deprecated since version 2.1.0.',
DeprecationWarning)
super(Monitor, self).__init__(*args, **kwargs)

+ 65
- 0
brainpy/compat/runners.py View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-

import warnings

from brainpy.dyn import runners as dyn_runner
from brainpy.integrators import runner as intg_runner

__all__ = [
'IntegratorRunner',
'DSRunner',
'StructRunner',
'ReportRunner'
]


class IntegratorRunner(intg_runner.IntegratorRunner):
"""Integrator runner class.

.. deprecated:: 2.1.0
Please use "brainpy.integrators.IntegratorRunner" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.integrators.IntegratorRunner" instead. '
'"brainpy.IntegratorRunner" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(IntegratorRunner, self).__init__(*args, **kwargs)


class DSRunner(dyn_runner.DSRunner):
"""Dynamical system runner class.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.DSRunner" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.DSRunner" instead. '
'"brainpy.DSRunner" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(DSRunner, self).__init__(*args, **kwargs)


class StructRunner(dyn_runner.StructRunner):
"""Dynamical system runner class.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.StructRunner" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.StructRunner" instead. '
'"brainpy.StructRunner" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(StructRunner, self).__init__(*args, **kwargs)


class ReportRunner(dyn_runner.ReportRunner):
"""Dynamical system runner class.

.. deprecated:: 2.1.0
Please use "brainpy.dyn.ReportRunner" instead.
"""
def __init__(self, *args, **kwargs):
warnings.warn('Please use "brainpy.dyn.ReportRunner" instead. '
'"brainpy.ReportRunner" is deprecated since '
'version 2.1.0', DeprecationWarning)
super(ReportRunner, self).__init__(*args, **kwargs)

+ 2
- 2
brainpy/connect/tests/test_regular_conn.py View File

@@ -14,7 +14,7 @@ def test_one2one():
num = bp.tools.size2num(size)

actual_mat = bp.math.zeros((num, num), dtype=bp.math.bool_)
actual_mat = bp.math.fill_diagonal(actual_mat, True)
bp.math.fill_diagonal(actual_mat, True)

assert bp.math.array_equal(actual_mat, conn_mat)
assert bp.math.array_equal(pre_ids, bp.math.arange(num))
@@ -42,7 +42,7 @@ def test_all2all():
print(mat)
actual_mat = bp.math.ones((num, num), dtype=bp.math.bool_)
if not has_self:
actual_mat = bp.math.fill_diagonal(actual_mat, False)
bp.math.fill_diagonal(actual_mat, False)

assert bp.math.array_equal(actual_mat, mat)



+ 2
- 2
brainpy/datasets/chaotic_systems.py View File

@@ -167,8 +167,8 @@ def mackey_glass_series(duration, dt=0.1, beta=2., gamma=1., tau=2., n=9.65,
assert isinstance(inits, (bm.ndarray, jnp.ndarray))

rng = bm.random.RandomState(seed)
xdelay = bm.FixedLenDelay(inits.shape, tau, dt=dt)
xdelay.data = inits + 0.2 * (rng.random((xdelay.num_delay_steps,) + inits.shape) - 0.5)
xdelay = bm.TimeDelay(inits, tau, dt=dt)
xdelay.data = inits + 0.2 * (rng.random((xdelay.num_delay_step,) + inits.shape) - 0.5)

@ddeint(method=method, state_delays={'x': xdelay})
def mg_eq(x, t):


+ 0
- 752
brainpy/dyn/neurons/IF_models.py View File

@@ -1,752 +0,0 @@
# -*- coding: utf-8 -*-

import brainpy.math as bm
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.dyn.base import NeuGroup

__all__ = [
'LIF',
'ExpIF',
'AdExIF',
'QuaIF',
'AdQuaIF',
'GIF',
]


class LIF(NeuGroup):
r"""Leaky integrate-and-fire neuron model.

**Model Descriptions**

The formal equations of a LIF model [1]_ is given by:

.. math::

\tau \frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \\
\text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad
\text{last} \quad \tau_{ref} \quad \text{ms}

where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting
membrane potential, :math:`V_{reset}` is the reset membrane potential,
:math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant,
:math:`\tau_{ref}` is the refractory time period,
and :math:`I` is the time-variant synaptic inputs.

**Model Examples**

- `(Brette, Romain. 2004) LIF phase locking <https://brainpy-examples.readthedocs.io/en/latest/neurons/Romain_2004_LIF_phase_locking.html>`_

**Model Parameters**

============= ============== ======== =========================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- -----------------------------------------
V_rest 0 mV Resting membrane potential.
V_reset -5 mV Reset potential after spike.
V_th 20 mV Threshold potential of spike.
tau 10 ms Membrane time constant. Compute by R * C.
tau_ref 5 ms Refractory period length.(ms)
============= ============== ======== =========================================

**Neuron Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V 0 Membrane potential.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
refractory False Flag to mark whether the neuron is in refractory period.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model
neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304.
"""

def __init__(self, size, V_rest=0., V_reset=-5., V_th=20., tau=10.,
tau_ref=1., method='exp_auto', name=None):
# initialization
super(LIF, self).__init__(size=size, name=name)

# parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.tau = tau
self.tau_ref = tau_ref

# variables
self.V = bm.Variable(bm.zeros(self.num))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))

# integral
self.integral = odeint(method=method, f=self.derivative)

def derivative(self, V, t, I_ext):
dvdt = (-V + self.V_rest + I_ext) / self.tau
return dvdt

def update(self, _t, _dt):
refractory = (_t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, _t, self.input, dt=_dt)
V = bm.where(refractory, self.V, V)
spike = V >= self.V_th
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
self.spike.value = spike
self.input[:] = 0.


class ExpIF(NeuGroup):
r"""Exponential integrate-and-fire neuron model.

**Model Descriptions**

In the exponential integrate-and-fire model [1]_, the differential
equation for the membrane potential is given by

.. math::

\tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\
\text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms}

This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}`
and "threshold" :math:`\vartheta_{rh}`.

The moment when the membrane potential reaches the numerical threshold :math:`V_{th}`
defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to
:math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`,
where :math:`\tau_{\rm ref}` is an absolute refractory time.
If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`,
its exact value does not play any role. The reason is that the upswing of the action
potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in
an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical
convenience. For a formal mathematical analysis of the model, the threshold can be pushed
to infinity.

The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk
and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_.
It is one of the prominent examples of a precise theoretical prediction in computational
neuroscience that was later confirmed by experimental neuroscience.

Two important remarks:

- (i) The right-hand side of the above equation contains a nonlinearity
that can be directly extracted from experimental data [3]_. In this sense the exponential
nonlinearity is not an arbitrary choice but directly supported by experimental evidence.
- (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing
rate for constant input, and the linear response to fluctuations, even in the presence
of input noise [4]_.

**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>> group = bp.dyn.ExpIF(1)
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 10.))
>>> runner.run(300., )
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)


**Model Parameters**

============= ============== ======== ===================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- ---------------------------------------------------
V_rest -65 mV Resting potential.
V_reset -68 mV Reset potential after spike.
V_th -30 mV Threshold potential of spike.
V_T -59.9 mV Threshold potential of generating action potential.
delta_T 3.48 \ Spike slope factor.
R 1 \ Membrane resistance.
tau 10 \ Membrane time constant. Compute by R * C.
tau_ref 1.7 \ Refractory period length.
============= ============== ======== ===================================================

**Model Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V 0 Membrane potential.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
refractory False Flag to mark whether the neuron is in refractory period.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation
mechanisms determine the neuronal response to fluctuating
inputs." Journal of Neuroscience 23.37 (2003): 11628-11640.
.. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
Neuronal dynamics: From single neurons to networks and models
of cognition. Cambridge University Press.
.. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen,
Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves
are reliable predictors of naturalistic pyramidal-neuron voltage
traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666.
.. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear
integrate-and-fire neurons to modulated current-based and
conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919.
.. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire
"""

def __init__(self, size, V_rest=-65., V_reset=-68., V_th=-30., V_T=-59.9, delta_T=3.48,
R=1., tau=10., tau_ref=1.7, method='exp_auto', name=None):
# initialize
super(ExpIF, self).__init__(size=size, name=name)

# parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.V_T = V_T
self.delta_T = delta_T
self.R = R
self.tau = tau
self.tau_ref = tau_ref

# variables
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
# variables
self.V = bm.Variable(bm.zeros(self.num))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# integral
self.integral = odeint(method=method, f=self.derivative)

def derivative(self, V, t, I_ext):
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau
return dvdt

def update(self, _t, _dt):
refractory = (_t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, _t, self.input, dt=_dt)
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
self.spike.value = spike
self.input[:] = 0.


class AdExIF(NeuGroup):
r"""Adaptive exponential integrate-and-fire neuron model.

**Model Descriptions**

The **adaptive exponential integrate-and-fire model**, also called AdEx, is a
spiking neuron model with two variables [1]_ [2]_.

.. math::

\begin{aligned}
\tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\
\tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w
\end{aligned}

once the membrane potential reaches the spike threshold,

.. math::

V \rightarrow V_{reset}, \\
w \rightarrow w+b.

The first equation describes the dynamics of the membrane potential and includes
an activation term with an exponential voltage dependence. Voltage is coupled to
a second equation which describes adaptation. Both variables are reset if an action
potential has been triggered. The combination of adaptation and exponential voltage
dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model.

The adaptive exponential integrate-and-fire model is capable of describing known
neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation,
initial bursting, fast spiking, and regular spiking.

**Model Examples**

- `Examples for different firing patterns <https://brainpy-examples.readthedocs.io/en/latest/neurons/AdExIF_model.html>`_

**Model Parameters**

============= ============== ======== ========================================================================================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------
V_rest -65 mV Resting potential.
V_reset -68 mV Reset potential after spike.
V_th -30 mV Threshold potential of spike and reset.
V_T -59.9 mV Threshold potential of generating action potential.
delta_T 3.48 \ Spike slope factor.
a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v`
b 1 \ The increment of :math:`w` produced by a spike.
R 1 \ Membrane resistance.
tau 10 ms Membrane time constant. Compute by R * C.
tau_w 30 ms Time constant of the adaptation current.
============= ============== ======== ========================================================================================================================

**Model Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V 0 Membrane potential.
w 0 Adaptation current.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation
mechanisms determine the neuronal response to fluctuating
inputs." Journal of Neuroscience 23.37 (2003): 11628-11640.
.. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model
"""

def __init__(self, size, V_rest=-65., V_reset=-68., V_th=-30., V_T=-59.9, delta_T=3.48, a=1.,
b=1., tau=10., tau_w=30., R=1., method='exp_auto', name=None):
super(AdExIF, self).__init__(size=size, name=name)

# parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.V_T = V_T
self.delta_T = delta_T
self.a = a
self.b = b
self.tau = tau
self.tau_w = tau_w
self.R = R

# variables
self.w = bm.Variable(bm.zeros(self.num))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = bm.Variable(bm.zeros(self.num))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# functions
self.integral = odeint(method=method, f=self.derivative)

def dV(self, V, t, w, I_ext):
dVdt = (- V + self.V_rest + self.delta_T * bm.exp((V - self.V_T) / self.delta_T) -
self.R * w + self.R * I_ext) / self.tau
return dVdt

def dw(self, w, t, V):
dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w
return dwdt

@property
def derivative(self):
return JointEq([self.dV, self.dw])

def update(self, _t, _dt):
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt)
spike = V >= self.V_th
self.t_last_spike[:] = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.w.value = bm.where(spike, w + self.b, w)
self.spike.value = spike
self.input[:] = 0.


class QuaIF(NeuGroup):
r"""Quadratic Integrate-and-Fire neuron model.

**Model Descriptions**

In contrast to physiologically accurate but computationally expensive
neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only
to produce **action potential-like patterns** and ignores subtleties
like gating variables, which play an important role in generating action
potentials in a real neuron. However, the QIF model is incredibly easy
to implement and compute, and relatively straightforward to study and
understand, thus has found ubiquitous use in computational neuroscience.

.. math::

\tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t)

where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000).

**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>>
>>> group = bp.dyn.QuaIF(1,)
>>>
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 20.))
>>> runner.run(duration=200.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)


**Model Parameters**

============= ============== ======== ========================================================================================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------
V_rest -65 mV Resting potential.
V_reset -68 mV Reset potential after spike.
V_th -30 mV Threshold potential of spike and reset.
V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest.
c .07 \ Coefficient describes membrane potential update. Larger than 0.
R 1 \ Membrane resistance.
tau 10 ms Membrane time constant. Compute by R * C.
tau_ref 0 ms Refractory period length.
============= ============== ======== ========================================================================================================================

**Model Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V 0 Membrane potential.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
refractory False Flag to mark whether the neuron is in refractory period.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg
(2000) Intrinsic dynamics in neuronal networks. I. Theory.
J. Neurophysiology 83, pp. 808–827.
"""

def __init__(self, size, V_rest=-65., V_reset=-68., V_th=-30., V_c=-50.0, c=.07,
R=1., tau=10., tau_ref=0., method='exp_auto', name=None):
# initialization
super(QuaIF, self).__init__(size=size, name=name)

# parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.V_c = V_c
self.c = c
self.R = R
self.tau = tau
self.tau_ref = tau_ref

# variables
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
# variables
self.V = bm.Variable(bm.zeros(self.num))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# integral
self.integral = odeint(method=method, f=self.derivative)

def derivative(self, V, t, I_ext):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau
return dVdt

def update(self, _t, _dt, **kwargs):
refractory = (_t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, _t, self.input, dt=_dt)
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
self.spike.value = spike
self.input[:] = 0.


class AdQuaIF(NeuGroup):
r"""Adaptive quadratic integrate-and-fire neuron model.

**Model Descriptions**

The adaptive quadratic integrate-and-fire neuron model [1]_ is given by:

.. math::

\begin{aligned}
\tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\
\tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w,
\end{aligned}

once the membrane potential reaches the spike threshold,

.. math::

V \rightarrow V_{reset}, \\
w \rightarrow w+b.

**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>> group = bp.dyn.AdQuaIF(1, )
>>> runner = bp.dyn.DSRunner(group, monitors=['V', 'w'], inputs=('input', 30.))
>>> runner.run(300)
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V')
>>> fig.add_subplot(gs[1, 0])
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)

**Model Parameters**

============= ============== ======== =======================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- -------------------------------------------------------
V_rest -65 mV Resting potential.
V_reset -68 mV Reset potential after spike.
V_th -30 mV Threshold potential of spike and reset.
V_c -50 mV Critical voltage for spike initiation. Must be larger
than :math:`V_{rest}`.
a 1 \ The sensitivity of the recovery variable :math:`u` to
the sub-threshold fluctuations of the membrane
potential :math:`v`
b .1 \ The increment of :math:`w` produced by a spike.
c .07 \ Coefficient describes membrane potential update.
Larger than 0.
tau 10 ms Membrane time constant.
tau_w 10 ms Time constant of the adaptation current.
============= ============== ======== =======================================================

**Model Variables**

================== ================= ==========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ----------------------------------------------------------
V 0 Membrane potential.
w 0 Adaptation current.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
t_last_spike -1e7 Last spike time stamp.
================== ================= ==========================================================

**References**

.. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking
neurons?. IEEE transactions on neural networks, 15(5), 1063-1070.
.. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of
nonlinear integrate-and-fire neurons." SIAM Journal on Applied
Mathematics 68, no. 4 (2008): 1045-1079.
"""

def __init__(self, size, V_rest=-65., V_reset=-68., V_th=-30., V_c=-50.0, a=1., b=.1,
c=.07, tau=10., tau_w=10., method='exp_auto', name=None):
super(AdQuaIF, self).__init__(size=size, name=name)

# parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.V_c = V_c
self.c = c
self.a = a
self.b = b
self.tau = tau
self.tau_w = tau_w

# variables
self.V = bm.Variable(bm.zeros(self.num))
self.w = bm.Variable(bm.zeros(self.num))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))

# integral
self.integral = odeint(method=method, f=self.derivative)

def dV(self, V, t, w, I_ext):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau
return dVdt

def dw(self, w, t, V):
dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w
return dwdt

@property
def derivative(self):
return JointEq([self.dV, self.dw])

def update(self, _t, _dt):
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.w.value = bm.where(spike, w + self.b, w)
self.spike.value = spike
self.input[:] = 0.


class GIF(NeuGroup):
r"""Generalized Integrate-and-Fire model.

**Model Descriptions**

The generalized integrate-and-fire model [1]_ is given by

.. math::

&\frac{d I_j}{d t} = - k_j I_j

&\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau

&\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty})

When :math:`V` meet :math:`V_{th}`, Generalized IF neuron fires:

.. math::

&I_j \leftarrow R_j I_j + A_j

&V \leftarrow V_{reset}

&V_{th} \leftarrow max(V_{th_{reset}}, V_{th})

Note that :math:`I_j` refers to arbitrary number of internal currents.

**Model Examples**

- `Detailed examples to reproduce different firing patterns <https://brainpy-examples.readthedocs.io/en/latest/neurons/Niebur_2009_GIF.html>`_

**Model Parameters**

============= ============== ======== ====================================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- --------------------------------------------------------------------
V_rest -70 mV Resting potential.
V_reset -70 mV Reset potential after spike.
V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating.
V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`.
R 20 \ Membrane resistance.
tau 20 ms Membrane time constant. Compute by :math:`R * C`.
a 0 \ Coefficient describes the dependence of
:math:`V_{th}` on membrane potential.
b 0.01 \ Coefficient describes :math:`V_{th}` update.
k1 0.2 \ Constant pf :math:`I1`.
k2 0.02 \ Constant of :math:`I2`.
R1 0 \ Free parameter.
Describes dependence of :math:`I_1` reset value on
:math:`I_1` value before spiking.
R2 1 \ Free parameter.
Describes dependence of :math:`I_2` reset value on
:math:`I_2` value before spiking.
A1 0 \ Free parameter.
A2 0 \ Free parameter.
============= ============== ======== ====================================================================

**Model Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V -70 Membrane potential.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
V_th -50 Spiking threshold potential.
I1 0 Internal current 1.
I2 0 Internal current 2.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear
integrate-and-fire neural model produces diverse spiking
behaviors." Neural computation 21.3 (2009): 704-718.
.. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan
Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized
leaky integrate-and-fire models classify multiple neuron types."
Nature communications 9, no. 1 (2018): 1-15.
"""

def __init__(self, size, V_rest=-70., V_reset=-70., V_th_inf=-50., V_th_reset=-60.,
R=20., tau=20., a=0., b=0.01, k1=0.2, k2=0.02, R1=0., R2=1., A1=0.,
A2=0., method='exp_auto', name=None):
# initialization
super(GIF, self).__init__(size=size, name=name)

# params
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th_inf = V_th_inf
self.V_th_reset = V_th_reset
self.R = R
self.tau = tau
self.a = a
self.b = b
self.k1 = k1
self.k2 = k2
self.R1 = R1
self.R2 = R2
self.A1 = A1
self.A2 = A2

# variables
self.I1 = bm.Variable(bm.zeros(self.num))
self.I2 = bm.Variable(bm.zeros(self.num))
self.V_th = bm.Variable(bm.ones(self.num) * -50.)
self.V = bm.Variable(bm.zeros(self.num))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# integral
self.integral = odeint(method=method, f=self.derivative)

def dI1(self, I1, t):
return - self.k1 * I1

def dI2(self, I2, t):
return - self.k2 * I2

def dVth(self, V_th, t, V):
return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf)

def dV(self, V, t, I1, I2, I_ext):
return (- (V - self.V_rest) + self.R * I_ext + self.R * I1 + self.R * I2) / self.tau

@property
def derivative(self):
return JointEq([self.dI1, self.dI2, self.dVth, self.dV])

def update(self, _t, _dt):
I1, I2, V_th, V = self.integral(self.I1, self.I2, self.V_th, self.V, _t, self.input, dt=_dt)
spike = self.V_th <= V
V = bm.where(spike, self.V_reset, V)
I1 = bm.where(spike, self.R1 * I1 + self.A1, I1)
I2 = bm.where(spike, self.R2 * I2 + self.A2, I2)
reset_th = bm.logical_and(V_th < self.V_th_reset, spike)
V_th = bm.where(reset_th, self.V_th_reset, V_th)
self.spike.value = spike
self.I1.value = I1
self.I2.value = I2
self.V_th.value = V_th
self.V.value = V
self.input[:] = 0.

+ 2
- 1
brainpy/dyn/neurons/__init__.py View File

@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-

from .biological_models import *
from .IF_models import *
from .fractional_models import *
from .input_models import *
from .noise_models import *
from .rate_models import *
from .reduced_models import *

+ 58
- 13
brainpy/dyn/neurons/biological_models.py View File

@@ -1,9 +1,14 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import OneInit, Uniform, Initializer, init_param
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.dyn.base import NeuGroup
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Parameter, Tensor

__all__ = [
'HH',
@@ -178,8 +183,24 @@ class HH(NeuGroup):
The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92.
"""

def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03,
V_th=20., C=1.0, method='exp_auto', name=None):
def __init__(
self,
size: Shape,
ENa: Parameter = 50.,
gNa: Parameter = 120.,
EK: Parameter = -77.,
gK: Parameter = 36.,
EL: Parameter = -54.387,
gL: Parameter = 0.03,
V_th: Parameter = 20.,
C: Parameter = 1.0,
V_initializer: Union[Initializer, Callable, Tensor] = Uniform(-70, -60.),
m_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.5),
h_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.6),
n_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.32),
method: str = 'exp_auto',
name: str = None
):
# initialization
super(HH, self).__init__(size=size, name=name)

@@ -194,10 +215,14 @@ class HH(NeuGroup):
self.V_th = V_th

# variables
self.m = bm.Variable(0.5 * bm.ones(self.num))
self.h = bm.Variable(0.6 * bm.ones(self.num))
self.n = bm.Variable(0.32 * bm.ones(self.num))
self.V = bm.Variable(bm.zeros(self.num))
check_initializer(m_initializer, 'm_initializer', allow_none=False)
check_initializer(h_initializer, 'h_initializer', allow_none=False)
check_initializer(n_initializer, 'n_initializer', allow_none=False)
check_initializer(V_initializer, 'V_initializer', allow_none=False)
self.m = bm.Variable(init_param(m_initializer, (self.num,)))
self.h = bm.Variable(init_param(h_initializer, (self.num,)))
self.n = bm.Variable(init_param(n_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
@@ -334,11 +359,29 @@ class MorrisLecar(NeuGroup):
.. [3] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model
"""

def __init__(self, size, V_Ca=130., g_Ca=4.4, V_K=-84., g_K=8., V_leak=-60.,
g_leak=2., C=20., V1=-1.2, V2=18., V3=2., V4=30., phi=0.04,
V_th=10., method='exp_auto', name=None):
def __init__(
self,
size: Shape,
V_Ca: Parameter = 130.,
g_Ca: Parameter = 4.4,
V_K: Parameter = -84.,
g_K: Parameter = 8.,
V_leak: Parameter = -60.,
g_leak: Parameter = 2.,
C: Parameter = 20.,
V1: Parameter = -1.2,
V2: Parameter = 18.,
V3: Parameter = 2.,
V4: Parameter = 30.,
phi: Parameter = 0.04,
V_th: Parameter = 10.,
W_initializer: Union[Callable, Initializer, Tensor] = OneInit(0.02),
V_initializer: Union[Callable, Initializer, Tensor] = Uniform(-70., -60.),
method: str = 'exp_auto',
name: str = None
):
# initialization
super(MorrisLecar, self).__init__(size=size, name=name)
super(MorrisLecar, self).__init__(size=size, name=name)

# params
self.V_Ca = V_Ca
@@ -356,8 +399,10 @@ class MorrisLecar(NeuGroup):
self.V_th = V_th

# vars
self.W = bm.Variable(bm.ones(self.num) * 0.02)
self.V = bm.Variable(bm.zeros(self.num))
check_initializer(V_initializer, 'V_initializer', allow_none=False)
check_initializer(W_initializer, 'W_initializer', allow_none=False)
self.W = bm.Variable(init_param(W_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)


+ 294
- 0
brainpy/dyn/neurons/fractional_models.py View File

@@ -0,0 +1,294 @@
# -*- coding: utf-8 -*-

from typing import Union, Sequence, Callable

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import ZeroInit, OneInit, Initializer, init_param
from brainpy.integrators.fde import CaputoL1Schema
from brainpy.integrators.fde import GLShortMemory
from brainpy.integrators.joint_eq import JointEq
from brainpy.tools.checking import check_float, check_integer
from brainpy.tools.checking import check_initializer
from brainpy.types import Parameter, Shape, Tensor

__all__ = [
'FractionalNeuron',
'FractionalFHR',
'FractionalIzhikevich',
]


class FractionalNeuron(NeuGroup):
"""Fractional-order neuron model."""
pass


class FractionalFHR(FractionalNeuron):
r"""The fractional-order FH-R model [1]_.

FitzHugh and Rinzel introduced FH-R model (1976, in an unpublished article),
which is the modification of the classical FHN neuron model. The fractional-order
FH-R model is described as

.. math::

\begin{array}{rcl}
\frac{{d}^{\alpha }v}{d{t}^{\alpha }} & = & v-{v}^{3}/3-w+y+I={f}_{1}(v,w,y),\\
\frac{{d}^{\alpha }w}{d{t}^{\alpha }} & = & \delta (a+v-bw)={f}_{2}(v,w,y),\\
\frac{{d}^{\alpha }y}{d{t}^{\alpha }} & = & \mu (c-v-dy)={f}_{3}(v,w,y),
\end{array}

where :math:`v, w` and :math:`y` represent the membrane voltage, recovery variable
and slow modulation of the current respectively.
:math:`I` measures the constant magnitude of external stimulus current, and :math:`\alpha`
is the fractional exponent which ranges in the interval :math:`(0 < \alpha \le 1)`.
:math:`a, b, c, d, \delta` and :math:`\mu` are the system parameters.

The system reduces to the original classical order system when :math:`\alpha=1`.

:math:`\mu` indicates a small parameter that determines the pace of the slow system
variable :math:`y`. The fast subsystem (:math:`v-w`) presents a relaxation oscillator
in the phase plane where :math:`\delta` is a small parameter.
:math:`v` is expressed in mV (millivolt) scale. Time :math:`t` is in ms (millisecond) scale.
It exhibits tonic spiking or quiescent state depending on the parameter sets for a fixed
value of :math:`I`. The parameter :math:`a` in the 2D FHN model corresponds to the
parameter :math:`c` of the FH-R neuron model. If we decrease the value of :math:`a`,
it causes longer intervals between two burstings, however there exists :math:`a`
relatively fixed time of bursting duration. With the increasing of :math:`a`, the
interburst intervals become shorter and periodic bursting changes to tonic spiking.

Examples
--------

- [(Mondal, et, al., 2019): Fractional-order FitzHugh-Rinzel bursting neuron model](https://brainpy-examples.readthedocs.io/en/latest/neurons/2019_Fractional_order_FHR_model.html)


Parameters
----------
size: int, sequence of int
The size of the neuron group.
alpha: float, tensor
The fractional order.
num_memory: int
The total number of the short memory.

References
----------
.. [1] Mondal, A., Sharma, S.K., Upadhyay, R.K. *et al.* Firing activities of a fractional-order FitzHugh-Rinzel bursting neuron model and its coupled dynamics. *Sci Rep* **9,** 15721 (2019). https://doi.org/10.1038/s41598-019-52061-4
"""

def __init__(
self,
size: Shape,
alpha: Union[float, Sequence[float]],
num_memory: int = 1000,
a: Parameter = 0.7,
b: Parameter = 0.8,
c: Parameter = -0.775,
d: Parameter = 1.,
delta: Parameter = 0.08,
mu: Parameter = 0.0001,
Vth: Parameter = 1.8,
V_initializer: Union[Initializer, Callable, Tensor] = OneInit(2.5),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
y_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
name: str = None
):
super(FractionalFHR, self).__init__(size, name=name)

# fractional order
self.alpha = alpha
check_integer(num_memory, 'num_memory', allow_none=False)

# parameters
self.a = a
self.b = b
self.c = c
self.d = d
self.delta = delta
self.mu = mu
self.Vth = Vth

# variables
check_initializer(V_initializer, 'V_initializer', allow_none=False)
check_initializer(w_initializer, 'w_initializer', allow_none=False)
check_initializer(y_initializer, 'y_initializer', allow_none=False)
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# integral function
self.integral = GLShortMemory(self.derivative,
alpha=alpha,
num_memory=num_memory,
inits=[self.V, self.w, self.y])

def dV(self, V, t, w, y):
return V - V ** 3 / 3 - w + y + self.input

def dw(self, w, t, V):
return self.delta * (self.a + V - self.b * w)

def dy(self, y, t, V):
return self.mu * (self.c - V - self.d * y)

@property
def derivative(self):
return JointEq([self.dV, self.dw, self.dy])

def update(self, _t, _dt):
V, w, y = self.integral(self.V, self.w, self.y, _t, _dt)
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth)
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
self.V.value = V
self.w.value = w
self.y.value = y
self.input[:] = 0.

def set_init(self, values: dict):
for k, v in values.items():
if k not in self.integral.inits:
raise ValueError(f'Variable "{k}" is not defined in this model.')
variable = getattr(self, k)
variable[:] = v
self.integral.inits[k][:] = v


class FractionalIzhikevich(FractionalNeuron):
r"""Fractional-order Izhikevich model [10]_.

The fractional-order Izhikevich model is given by

.. math::

\begin{aligned}
&\tau \frac{d^{\alpha} v}{d t^{\alpha}}=\mathrm{f} v^{2}+g v+h-u+R I \\
&\tau \frac{d^{\alpha} u}{d t^{\alpha}}=a(b v-u)
\end{aligned}

where :math:`\alpha` is the fractional order (exponent) such that :math:`0<\alpha\le1`.
It is a commensurate system that reduces to classical Izhikevich model at :math:`\alpha=1`.

The time :math:`t` is in ms; and the system variable :math:`v` expressed in mV
corresponds to membrane voltage. Moreover, :math:`u` expressed in mV is the
recovery variable that corresponds to the activation of K+ ionic current and
inactivation of Na+ ionic current.

The parameters :math:`f, g, h` are fixed constants (should not be changed) such
that :math:`f=0.04` (mV)−1, :math:`g=5, h=140` mV; and :math:`a` and :math:`b` are
dimensionless parameters. The time constant :math:`\tau=1` ms; the resistance
:math:`R=1` Ω; and :math:`I` expressed in mA measures the injected (applied)
dc stimulus current to the system.

When the membrane voltage reaches the spike peak :math:`v_{peak}`, the two variables
are rest as follow:

.. math::

\text { if } v \geq v_{\text {peak }} \text { then }\left\{\begin{array}{l}
v \leftarrow c \\
u \leftarrow u+d
\end{array}\right.

we used :math:`v_{peak}=30` mV, and :math:`c` and :math:`d` are parameters expressed
in mV. When the spike reaches its peak value, the membrane voltage :math:`v` and the
recovery variable :math:`u` are reset according to the above condition.

Examples
--------

- [(Teka, et. al, 2018): Fractional-order Izhikevich neuron model](https://brainpy-examples.readthedocs.io/en/latest/neurons/2018_Fractional_Izhikevich_model.html)


References
----------
.. [10] Teka, Wondimu W., Ranjit Kumar Upadhyay, and Argha Mondal. "Spiking and
bursting patterns of fractional-order Izhikevich model." Communications
in Nonlinear Science and Numerical Simulation 56 (2018): 161-176.

"""

def __init__(
self,
size: Shape,
alpha: Union[float, Sequence[float]],
num_step: int,
a: Parameter = 0.02,
b: Parameter = 0.20,
c: Parameter = -65.,
d: Parameter = 8.,
f: Parameter = 0.04,
g: Parameter = 5.,
h: Parameter = 140.,
tau: Parameter = 1.,
R: Parameter = 1.,
V_th: Parameter = 30.,
V_initializer: Union[Initializer, Callable, Tensor] = OneInit(-65.),
u_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.20 * -65.),
name: str = None
):
# initialization
super(FractionalIzhikevich, self).__init__(size=size, name=name)

# params
self.alpha = alpha
check_float(alpha, 'alpha', min_bound=0., max_bound=1., allow_none=False, allow_int=True)
self.a = a
self.b = b
self.c = c
self.d = d
self.f = f
self.g = g
self.h = h
self.tau = tau
self.R = R
self.V_th = V_th

# variables
check_initializer(V_initializer, 'V_initializer', allow_none=False)
check_initializer(u_initializer, 'u_initializer', allow_none=False)
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.u = bm.Variable(init_param(u_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# functions
check_integer(num_step, 'num_step', allow_none=False)
self.integral = CaputoL1Schema(f=self.derivative,
alpha=alpha,
num_step=num_step,
inits=[self.V, self.u])

def dV(self, V, t, u, I_ext):
dVdt = self.f * V * V + self.g * V + self.h - u + self.R * I_ext
return dVdt / self.tau

def du(self, u, t, V):
dudt = self.a * (self.b * V - u)
return dudt / self.tau

@property
def derivative(self):
return JointEq([self.dV, self.du])

def update(self, _t, _dt):
V, u = self.integral(self.V, self.u, t=_t, I_ext=self.input, dt=_dt)
spikes = V >= self.V_th
self.t_last_spike.value = bm.where(spikes, _t, self.t_last_spike)
self.V.value = bm.where(spikes, self.c, V)
self.u.value = bm.where(spikes, u + self.d, u)
self.spike.value = spikes
self.input[:] = 0.

def set_init(self, values: dict):
for k, v in values.items():
if k not in self.integral.inits:
raise ValueError(f'Variable "{k}" is not defined in this model.')
variable = getattr(self, k)
variable[:] = v
self.integral.inits[k][:] = v

+ 72
- 0
brainpy/dyn/neurons/noise_models.py View File

@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.integrators.sde import sdeint
from brainpy.types import Parameter, Shape

__all__ = [
'OUProcess',
]


class OUProcess(NeuGroup):
r"""The Ornstein–Uhlenbeck process.

The Ornstein–Uhlenbeck process :math:`x_{t}` is defined by the following
stochastic differential equation:

.. math::

\tau dx_{t}=-\theta \,x_{t}\,dt+\sigma \,dW_{t}

where :math:`\theta >0` and :math:`\sigma >0` are parameters and :math:`W_{t}`
denotes the Wiener process.

Parameters
----------
size: int, sequence of int
The model size.
mean: Parameter
The noise mean value.
sigma: Parameter
The noise amplitude.
tau: Parameter
The decay time constant.
method: str
The numerical integration method for stochastic differential equation.
name: str
The model name.
"""

def __init__(
self,
size: Shape,
mean: Parameter,
sigma: Parameter,
tau: Parameter,
method: str = 'euler',
name: str = None
):
super(OUProcess, self).__init__(size=size, name=name)

# parameters
self.mean = mean
self.sigma = sigma
self.tau = tau

# variables
self.x = bm.Variable(bm.ones(self.num) * mean)

# integral functions
self.integral = sdeint(f=self.df, g=self.dg, method=method)

def df(self, x, t):
f_x_ou = (self.mean - x) / self.tau
return f_x_ou

def dg(self, x, t):
return self.sigma

def update(self, _t, _dt):
self.x.value = self.integral(self.x, _t, _dt)

+ 542
- 180
brainpy/dyn/neurons/rate_models.py View File

@@ -1,145 +1,155 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable

import numpy as np
from jax.experimental.host_callback import id_tap

import brainpy.math as bm
from brainpy import check
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import Initializer, Uniform
from brainpy.initialize import init_param
from brainpy.integrators.dde import ddeint
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.types import Parameter, Shape
from brainpy.tools.checking import check_float, check_initializer
from brainpy.types import Parameter, Shape, Tensor
from .noise_models import OUProcess

__all__ = [
'FHN',
'RateGroup',
'RateFHN',
'FeedbackFHN',
'MeanFieldQIF',
'RateQIF',
'StuartLandauOscillator',
'WilsonCowanModel',
]


class FHN(NeuGroup):
r"""FitzHugh-Nagumo neuron model.

**Model Descriptions**

The FitzHugh–Nagumo model (FHN), named after Richard FitzHugh (1922–2007)
who suggested the system in 1961 [1]_ and J. Nagumo et al. who created the
equivalent circuit the following year, describes a prototype of an excitable
system (e.g., a neuron).
class RateGroup(NeuGroup):
def update(self, _t, _dt):
raise NotImplementedError

The motivation for the FitzHugh-Nagumo model was to isolate conceptually
the essentially mathematical properties of excitation and propagation from
the electrochemical properties of sodium and potassium ion flow. The model
consists of

- a *voltage-like variable* having cubic nonlinearity that allows regenerative
self-excitation via a positive feedback, and
- a *recovery variable* having a linear dynamics that provides a slower negative feedback.
class RateFHN(NeuGroup):
r"""FitzHugh-Nagumo system used in [1]_.

.. math::

\begin{aligned}
{\dot {v}} &=v-{\frac {v^{3}}{3}}-w+RI_{\rm {ext}}, \\
\tau {\dot {w}}&=v+a-bw.
\end{aligned}

The FHN Model is an example of a relaxation oscillator
because, if the external stimulus :math:`I_{\text{ext}}`
exceeds a certain threshold value, the system will exhibit
a characteristic excursion in phase space, before the
variables :math:`v` and :math:`w` relax back to their rest values.
This behaviour is typical for spike generations (a short,
nonlinear elevation of membrane voltage :math:`v`,
diminished over time by a slower, linear recovery variable
:math:`w`) in a neuron after stimulation by an external
input current.

**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>> fhn = bp.dyn.FHN(1)
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['V', 'w'])
>>> runner.run(100.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w')
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True)

**Model Parameters**

============= ============== ======== ========================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- ------------------------
a 1 \ Positive constant
b 1 \ Positive constant
tau 10 ms Membrane time constant.
V_th 1.8 mV Threshold potential of spike.
============= ============== ======== ========================

**Model Variables**
\frac{dx}{dt} = -\alpha V^3 + \beta V^2 + \gamma V - w + I_{ext}\\
\tau \frac{dy}{dt} = (V - \delta - \epsilon w)

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V 0 Membrane potential.
w 0 A recovery variable which represents
the combined effects of sodium channel
de-inactivation and potassium channel
deactivation.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================
Parameters
----------
size: Shape
The model size.
x_ou_mean: Parameter
The noise mean of the :math:`x` variable, [mV/ms]
y_ou_mean: Parameter
The noise mean of the :math:`y` variable, [mV/ms].
x_ou_sigma: Parameter
The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)].
y_ou_sigma: Parameter
The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)].
x_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms].
y_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms].

**References**

.. [1] FitzHugh, Richard. "Impulses and physiological states in theoretical models of nerve membrane." Biophysical journal 1.6 (1961): 445-466.
.. [2] https://en.wikipedia.org/wiki/FitzHugh%E2%80%93Nagumo_model
.. [3] http://www.scholarpedia.org/article/FitzHugh-Nagumo_model
References
----------
.. [1] Kostova, T., Ravindran, R., & Schonbek, M. (2004). FitzHugh–Nagumo
revisited: Types of bifurcations, periodical forcing and stability
regions by a Lyapunov functional. International journal of
bifurcation and chaos, 14(03), 913-925.

"""

def __init__(self,
size: Shape,
a: Parameter = 0.7,
b: Parameter = 0.8,
tau: Parameter = 12.5,
Vth: Parameter = 1.8,
method: str = 'exp_auto',
name: str = None):
# initialization
super(FHN, self).__init__(size=size, name=name)

# parameters
self.a = a
self.b = b
def __init__(
self,
size: Shape,

# fhn parameters
alpha: Parameter = 3.0,
beta: Parameter = 4.0,
gamma: Parameter = -1.5,
delta: Parameter = 0.0,
epsilon: Parameter = 0.5,
tau: Parameter = 20.0,

# noise parameters
x_ou_mean: Parameter = 0.0,
x_ou_sigma: Parameter = 0.0,
x_ou_tau: Parameter = 5.0,
y_ou_mean: Parameter = 0.0,
y_ou_sigma: Parameter = 0.0,
y_ou_tau: Parameter = 5.0,

# other parameters
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
method: str = None,
sde_method: str = None,
name: str = None,
):
super(RateFHN, self).__init__(size=size, name=name)

# model parameters
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.delta = delta
self.epsilon = epsilon
self.tau = tau
self.Vth = Vth

# noise parameters
self.x_ou_mean = x_ou_mean # mV/ms, OU process
self.y_ou_mean = y_ou_mean # mV/ms, OU process
self.x_ou_sigma = x_ou_sigma # mV/ms/sqrt(ms), noise intensity
self.y_ou_sigma = y_ou_sigma # mV/ms/sqrt(ms), noise intensity
self.x_ou_tau = x_ou_tau # ms, timescale of the Ornstein-Uhlenbeck noise process
self.y_ou_tau = y_ou_tau # ms, timescale of the Ornstein-Uhlenbeck noise process

# variables
self.w = bm.Variable(bm.zeros(self.num))
self.V = bm.Variable(bm.zeros(self.num))
check_initializer(x_initializer, 'x_initializer')
check_initializer(y_initializer, 'y_initializer')
self.x = bm.Variable(init_param(x_initializer, (self.num,)))
self.y = bm.Variable(init_param(x_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# integral
self.integral = odeint(method=method, f=self.derivative)
# noise variables
self.x_ou = self.y_ou = None
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.):
self.x_ou = OUProcess(self.num,
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau,
method=sde_method)
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.):
self.y_ou = OUProcess(self.num,
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau,
method=sde_method)

def dV(self, V, t, w, I_ext):
return V - V * V * V / 3 - w + I_ext
# integral functions
self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method)

def dw(self, w, t, V):
return (V + self.a - self.b * w) / self.tau
def dx(self, x, t, y, x_ext):
return - self.alpha * x ** 3 + self.beta * x ** 2 + self.gamma * x - y + x_ext

@property
def derivative(self):
return JointEq([self.dV, self.dw])
def dy(self, y, t, x, y_ext=0.):
return (x - self.delta - self.epsilon * y) / self.tau + y_ext

def update(self, _t, _dt):
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt)
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth)
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
self.V.value = V
self.w.value = w
if self.x_ou is not None:
self.input += self.x_ou.x
self.x_ou.update(_t, _dt)
y_ext = 0.
if self.y_ou is not None:
y_ext = self.y_ou.x
self.y_ou.update(_t, _dt)
x, y = self.integral(self.x, self.y, _t, x_ext=self.input, y_ext=y_ext, dt=_dt)
self.x.value = x
self.y.value = y
self.input[:] = 0.


@@ -151,8 +161,8 @@ class FeedbackFHN(NeuGroup):
.. math::

\begin{aligned}
\frac{dv}{dt} &= v(t) - \frac{v^3(t)}{3} - w(t) + \mu[v(t-\mathrm{delay}) - v_0] \\
\frac{dw}{dt} &= [v(t) + a - b w(t)] / \tau
\frac{dx}{dt} &= x(t) - \frac{x^3(t)}{3} - y(t) + \mu[x(t-\mathrm{delay}) - x_0] \\
\frac{dy}{dt} &= [x(t) + a - b y(t)] / \tau
\end{aligned}


@@ -160,10 +170,10 @@ class FeedbackFHN(NeuGroup):

>>> import brainpy as bp
>>> fhn = bp.dyn.FeedbackFHN(1, delay=10.)
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['V', 'w'])
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['x', 'y'])
>>> runner.run(100.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w')
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y')
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True)


**Model Parameters**
@@ -181,6 +191,23 @@ class FeedbackFHN(NeuGroup):
when negative, it is a inhibitory feedback.
============= ============== ======== ========================

Parameters
----------
x_ou_mean: Parameter
The noise mean of the :math:`x` variable, [mV/ms]
y_ou_mean: Parameter
The noise mean of the :math:`y` variable, [mV/ms].
x_ou_sigma: Parameter
The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)].
y_ou_sigma: Parameter
The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)].
x_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms].
y_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms].



References
----------
.. [4] Plant, Richard E. (1981). *A FitzHugh Differential-Difference
@@ -189,61 +216,109 @@ class FeedbackFHN(NeuGroup):

"""

def __init__(self,
size: Shape,
a: Parameter = 0.7,
b: Parameter = 0.8,
delay: Parameter = 10.,
tau: Parameter = 12.5,
mu: Parameter = 1.6886,
v0: Parameter = -1,
Vth: Parameter = 1.8,
method: str = 'rk4',
name: str = None):
def __init__(
self,
size: Shape,

# model parameters
a: Parameter = 0.7,
b: Parameter = 0.8,
delay: Parameter = 10.,
tau: Parameter = 12.5,
mu: Parameter = 1.6886,
x0: Parameter = -1,

# noise parameters
x_ou_mean: Parameter = 0.0,
x_ou_sigma: Parameter = 0.0,
x_ou_tau: Parameter = 5.0,
y_ou_mean: Parameter = 0.0,
y_ou_sigma: Parameter = 0.0,
y_ou_tau: Parameter = 5.0,

# other parameters
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
method: str = 'rk4',
sde_method: str = None,
name: str = None,
dt: float = None
):
super(FeedbackFHN, self).__init__(size=size, name=name)

# dt
self.dt = bm.get_dt() if dt is None else dt
check_float(self.dt, 'dt', allow_none=False, min_bound=0., allow_int=False)

# parameters
self.a = a
self.b = b
self.delay = delay
self.tau = tau
self.mu = mu # feedback strength
self.v0 = v0 # resting potential
self.Vth = Vth
self.v0 = x0 # resting potential

# noise parameters
self.x_ou_mean = x_ou_mean
self.y_ou_mean = y_ou_mean
self.x_ou_sigma = x_ou_sigma
self.y_ou_sigma = y_ou_sigma
self.x_ou_tau = x_ou_tau
self.y_ou_tau = y_ou_tau

# variables
self.w = bm.Variable(bm.zeros(self.num))
self.V = bm.Variable(bm.zeros(self.num))
self.Vdelay = bm.FixedLenDelay(self.num, self.delay)
check_initializer(x_initializer, 'x_initializer')
check_initializer(y_initializer, 'y_initializer')
self.x = bm.Variable(init_param(x_initializer, (self.num,)))
self.y = bm.Variable(init_param(x_initializer, (self.num,)))
self.x_delay = bm.TimeDelay(self.x, self.delay, dt=self.dt, interp_method='round')
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# noise variables
self.x_ou = self.y_ou = None
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.):
self.x_ou = OUProcess(self.num,
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau,
method=sde_method)
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.):
self.y_ou = OUProcess(self.num,
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau,
method=sde_method)

# integral
self.integral = ddeint(method=method, f=self.derivative,
state_delays={'V': self.Vdelay})
self.integral = ddeint(method=method,
f=JointEq([self.dx, self.dy]),
state_delays={'V': self.x_delay})

def dV(self, V, t, w, Vdelay):
return (V - V * V * V / 3 - w + self.input +
self.mu * (Vdelay(t - self.delay) - self.v0))
def dx(self, x, t, y, x_ext):
return x - x * x * x / 3 - y + x_ext + self.mu * (self.x_delay(t - self.delay) - self.v0)

def dw(self, w, t, V):
return (V + self.a - self.b * w) / self.tau
def dy(self, y, t, x, y_ext):
return (x + self.a - self.b * y + y_ext) / self.tau

@property
def derivative(self):
return JointEq([self.dV, self.dw])
def _check_dt(self, dt, *args):
if np.absolute(dt - self.dt) > 1e-6:
raise ValueError(f'The "dt" {dt} used in model running is '
f'not consistent with the "dt" {self.dt} '
f'used in model definition.')

def update(self, _t, _dt):
V, w = self.integral(self.V, self.w, _t, Vdelay=self.Vdelay, dt=_dt)
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth)
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
self.V.value = V
self.w.value = w
if check.is_checking():
id_tap(self._check_dt, _dt)
if self.x_ou is not None:
self.input += self.x_ou.x
self.x_ou.update(_t, _dt)
y_ext = 0.
if self.y_ou is not None:
y_ext = self.y_ou.x
self.y_ou.update(_t, _dt)
x, y = self.integral(self.x, self.y, _t, x_ext=self.input, y_ext=y_ext, dt=_dt)
self.x.value = x
self.y.value = y
self.input[:] = 0.


class MeanFieldQIF(NeuGroup):
class RateQIF(NeuGroup):
r"""A mean-field model of a quadratic integrate-and-fire neuron population.

**Model Descriptions**
@@ -282,6 +357,21 @@ class MeanFieldQIF(NeuGroup):
J 15 \ the strength of the recurrent coupling inside the population
============= ============== ======== ========================

Parameters
----------
x_ou_mean: Parameter
The noise mean of the :math:`x` variable, [mV/ms]
y_ou_mean: Parameter
The noise mean of the :math:`y` variable, [mV/ms].
x_ou_sigma: Parameter
The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)].
y_ou_sigma: Parameter
The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)].
x_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms].
y_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms].


References
----------
@@ -294,15 +384,32 @@ class MeanFieldQIF(NeuGroup):

"""

def __init__(self,
size: Shape,
tau: Parameter = 1.,
eta: Parameter = -5.0,
delta: Parameter = 1.0,
J: Parameter = 15.,
method: str = 'exp_auto',
name: str = None):
super(MeanFieldQIF, self).__init__(size=size, name=name)
def __init__(
self,
size: Shape,

# model parameters
tau: Parameter = 1.,
eta: Parameter = -5.0,
delta: Parameter = 1.0,
J: Parameter = 15.,

# noise parameters
x_ou_mean: Parameter = 0.0,
x_ou_sigma: Parameter = 0.0,
x_ou_tau: Parameter = 5.0,
y_ou_mean: Parameter = 0.0,
y_ou_sigma: Parameter = 0.0,
y_ou_tau: Parameter = 5.0,

# other parameters
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
method: str = 'exp_auto',
name: str = None,
sde_method: str = None,
):
super(RateQIF, self).__init__(size=size, name=name)

# parameters
self.tau = tau #
@@ -310,54 +417,309 @@ class MeanFieldQIF(NeuGroup):
self.delta = delta # the half-width at half maximum of the Lorenzian distribution over the neural excitability
self.J = J # the strength of the recurrent coupling inside the population

# noise parameters
self.x_ou_mean = x_ou_mean
self.y_ou_mean = y_ou_mean
self.x_ou_sigma = x_ou_sigma
self.y_ou_sigma = y_ou_sigma
self.x_ou_tau = x_ou_tau
self.y_ou_tau = y_ou_tau

# variables
self.r = bm.Variable(bm.ones(1))
self.V = bm.Variable(bm.ones(1))
self.input = bm.Variable(bm.zeros(1))
check_initializer(x_initializer, 'x_initializer')
check_initializer(y_initializer, 'y_initializer')
self.x = bm.Variable(init_param(x_initializer, (self.num,)))
self.y = bm.Variable(init_param(x_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))

# noise variables
self.x_ou = self.y_ou = None
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.):
self.x_ou = OUProcess(self.num,
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau,
method=sde_method)
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.):
self.y_ou = OUProcess(self.num,
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau,
method=sde_method)

# functions
self.integral = odeint(self.derivative, method=method)
self.integral = odeint(JointEq([self.dx, self.dy]), method=method)

def dy(self, y, t, x, y_ext):
return (self.delta / (bm.pi * self.tau) + 2. * x * y + y_ext) / self.tau

def dx(self, x, t, y, x_ext):
return (x ** 2 + self.eta + x_ext + self.J * y * self.tau -
(bm.pi * y * self.tau) ** 2) / self.tau

def update(self, _t, _dt):
if self.x_ou is not None:
self.input += self.x_ou.x
self.x_ou.update(_t, _dt)
y_ext = 0.
if self.y_ou is not None:
y_ext = self.y_ou.x
self.y_ou.update(_t, _dt)
x, y = self.integral(self.x, self.y, t=_t, x_ext=self.input, y_ext=y_ext, dt=_dt)
self.x.value = x
self.y.value = y
self.input[:] = 0.


class StuartLandauOscillator(RateGroup):
r"""
Stuart-Landau model with Hopf bifurcation.

.. math::

\frac{dx}{dt} = (a - x^2 - y^2) * x - w*y + I^x_{ext} \\
\frac{dy}{dt} = (a - x^2 - y^2) * y + w*x + I^y_{ext}

Parameters
----------
x_ou_mean: Parameter
The noise mean of the :math:`x` variable, [mV/ms]
y_ou_mean: Parameter
The noise mean of the :math:`y` variable, [mV/ms].
x_ou_sigma: Parameter
The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)].
y_ou_sigma: Parameter
The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)].
x_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms].
y_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms].

"""

def __init__(
self,
size: Shape,

# model parameters
a=0.25,
w=0.2,

# noise parameters
x_ou_mean: Parameter = 0.0,
x_ou_sigma: Parameter = 0.0,
x_ou_tau: Parameter = 5.0,
y_ou_mean: Parameter = 0.0,
y_ou_sigma: Parameter = 0.0,
y_ou_tau: Parameter = 5.0,

# other parameters
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.5),
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.5),
method: str = None,
sde_method: str = None,
name: str = None,
):
super(StuartLandauOscillator, self).__init__(size=size,
name=name)

# model parameters
self.a = a
self.w = w

# noise parameters
self.x_ou_mean = x_ou_mean
self.y_ou_mean = y_ou_mean
self.x_ou_sigma = x_ou_sigma
self.y_ou_sigma = y_ou_sigma
self.x_ou_tau = x_ou_tau
self.y_ou_tau = y_ou_tau

# variables
check_initializer(x_initializer, 'x_initializer')
check_initializer(y_initializer, 'y_initializer')
self.x = bm.Variable(init_param(x_initializer, (self.num,)))
self.y = bm.Variable(init_param(x_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))

# noise variables
self.x_ou = self.y_ou = None
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.):
self.x_ou = OUProcess(self.num,
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau,
method=sde_method)
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.):
self.y_ou = OUProcess(self.num,
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau,
method=sde_method)

def dr(self, r, t, v):
return (self.delta / (bm.pi * self.tau) + 2. * r * v) / self.tau
# integral functions
self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method)

def dV(self, v, t, r):
return (v ** 2 + self.eta + self.input + self.J * r * self.tau -
(bm.pi * r * self.tau) ** 2) / self.tau
def dx(self, x, t, y, x_ext, a, w):
return (a - x * x - y * y) * x - w * y + x_ext

@property
def derivative(self):
return JointEq([self.dV, self.dr])
def dy(self, y, t, x, y_ext, a, w):
return (a - x * x - y * y) * y - w * y + y_ext

def update(self, _t, _dt):
self.V.value, self.r.value = self.integral(self.V, self.r, _t, _dt)
self.integral[:] = 0.
if self.x_ou is not None:
self.input += self.x_ou.x
self.x_ou.update(_t, _dt)
y_ext = 0.
if self.y_ou is not None:
y_ext = self.y_ou.x
self.y_ou.update(_t, _dt)
x, y = self.integral(self.x, self.y, _t, x_ext=self.input,
y_ext=y_ext, a=self.a, w=self.w, dt=_dt)
self.x.value = x
self.y.value = y
self.input[:] = 0.


class WilsonCowanModel(RateGroup):
"""Wilson-Cowan population model.

class VanDerPolOscillator(NeuGroup):
pass

Parameters
----------
x_ou_mean: Parameter
The noise mean of the :math:`x` variable, [mV/ms]
y_ou_mean: Parameter
The noise mean of the :math:`y` variable, [mV/ms].
x_ou_sigma: Parameter
The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)].
y_ou_sigma: Parameter
The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)].
x_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms].
y_ou_tau: Parameter
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms].

class ThetaNeuron(NeuGroup):
pass

"""

class MeanFieldQIFWithSFA(NeuGroup):
pass
def __init__(
self,
size: Shape,

# Excitatory parameters
E_tau=1., # excitatory time constant
E_a=1.2, # excitatory gain
E_theta=2.8, # excitatory firing threshold

# Inhibitory parameters
I_tau=1., # inhibitory time constant
I_a=1., # inhibitory gain
I_theta=4.0, # inhibitory firing threshold

# connection parameters
wEE=12., # local E-E coupling
wIE=4., # local E-I coupling
wEI=13., # local I-E coupling
wII=11., # local I-I coupling

# Refractory parameter
r=1,

# state initializer
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(max_val=0.05),
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(max_val=0.05),

# noise parameters
x_ou_mean: Parameter = 0.0,
x_ou_sigma: Parameter = 0.0,
x_ou_tau: Parameter = 5.0,
y_ou_mean: Parameter = 0.0,
y_ou_sigma: Parameter = 0.0,
y_ou_tau: Parameter = 5.0,

# other parameters
sde_method: str = None,
method: str = 'exp_euler_auto',
name: str = None,
):
super(WilsonCowanModel, self).__init__(size=size, name=name)

# model parameters
self.E_tau = E_tau
self.E_a = E_a
self.E_theta = E_theta
self.I_tau = I_tau
self.I_a = I_a
self.I_theta = I_theta
self.wEE = wEE
self.wIE = wIE
self.wEI = wEI
self.wII = wII
self.r = r

# noise parameters
self.x_ou_mean = x_ou_mean
self.y_ou_mean = y_ou_mean
self.x_ou_sigma = x_ou_sigma
self.y_ou_sigma = y_ou_sigma
self.x_ou_tau = x_ou_tau
self.y_ou_tau = y_ou_tau

# variables
check_initializer(x_initializer, 'x_initializer')
check_initializer(y_initializer, 'y_initializer')
self.x = bm.Variable(init_param(x_initializer, (self.num,)))
self.y = bm.Variable(init_param(x_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))

# noise variables
self.x_ou = self.y_ou = None
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.):
self.x_ou = OUProcess(self.num,
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau,
method=sde_method)
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.):
self.y_ou = OUProcess(self.num,
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau,
method=sde_method)

# functions
self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method)

class JansenRitModel(NeuGroup):
# functions
def F(self, x, a, theta):
return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta))

def dx(self, x, t, y, x_ext):
x = self.wEE * x - self.wIE * y + x_ext
return (-x + (1 - self.r * x) * self.F(x, self.E_a, self.E_theta)) / self.E_tau

def dy(self, y, t, x, y_ext):
x = self.wEI * x - self.wII * y + y_ext
return (-y + (1 - self.r * y) * self.F(x, self.I_a, self.I_theta)) / self.I_tau

def update(self, _t, _dt):
if self.x_ou is not None:
self.input += self.x_ou.x
self.x_ou.update(_t, _dt)
y_ext = 0.
if self.y_ou is not None:
y_ext = self.y_ou.x
self.y_ou.update(_t, _dt)
x, y = self.integral(self.x, self.y, _t, x_ext=self.input, y_ext=y_ext, dt=_dt)
self.x.value = x
self.y.value = y
self.input[:] = 0.


class JansenRitModel(RateGroup):
pass


class WilsonCowanModel(NeuGroup):
class KuramotoOscillator(RateGroup):
pass

class StuartLandauOscillator(NeuGroup):

class ThetaNeuron(RateGroup):
pass


class KuramotoOscillator(NeuGroup):
class RateQIFWithSFA(RateGroup):
pass


class VanDerPolOscillator(RateGroup):
pass

+ 1024
- 12
brainpy/dyn/neurons/reduced_models.py View File

@@ -1,16 +1,861 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import ZeroInit, OneInit, Initializer, init_param
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.dyn.base import NeuGroup
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Parameter, Tensor

__all__ = [
'LIF',
'ExpIF',
'AdExIF',
'QuaIF',
'AdQuaIF',
'GIF',
'Izhikevich',
'HindmarshRose',
'FHN',
]


class LIF(NeuGroup):
r"""Leaky integrate-and-fire neuron model.

**Model Descriptions**

The formal equations of a LIF model [1]_ is given by:

.. math::

\tau \frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \\
\text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad
\text{last} \quad \tau_{ref} \quad \text{ms}

where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting
membrane potential, :math:`V_{reset}` is the reset membrane potential,
:math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant,
:math:`\tau_{ref}` is the refractory time period,
and :math:`I` is the time-variant synaptic inputs.

**Model Examples**

- `(Brette, Romain. 2004) LIF phase locking <https://brainpy-examples.readthedocs.io/en/latest/neurons/Romain_2004_LIF_phase_locking.html>`_

**Model Parameters**

============= ============== ======== =========================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- -----------------------------------------
V_rest 0 mV Resting membrane potential.
V_reset -5 mV Reset potential after spike.
V_th 20 mV Threshold potential of spike.
tau 10 ms Membrane time constant. Compute by R * C.
tau_ref 5 ms Refractory period length.(ms)
============= ============== ======== =========================================

**Neuron Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V 0 Membrane potential.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
refractory False Flag to mark whether the neuron is in refractory period.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model
neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304.
"""

def __init__(
self,
size: Shape,
V_rest: Parameter = 0.,
V_reset: Parameter = -5.,
V_th: Parameter = 20.,
tau: Parameter = 10.,
tau_ref: Parameter = 1.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
method: str = 'exp_auto',
name: str = None
):
# initialization
super(LIF, self).__init__(size=size, name=name)

# parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.tau = tau
self.tau_ref = tau_ref

# variables
check_initializer(V_initializer, 'V_initializer')
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))

# integral
self.integral = odeint(method=method, f=self.derivative)

def derivative(self, V, t, I_ext):
dvdt = (-V + self.V_rest + I_ext) / self.tau
return dvdt

def update(self, _t, _dt):
refractory = (_t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, _t, self.input, dt=_dt)
V = bm.where(refractory, self.V, V)
spike = V >= self.V_th
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
self.spike.value = spike
self.input[:] = 0.


class ExpIF(NeuGroup):
r"""Exponential integrate-and-fire neuron model.

**Model Descriptions**

In the exponential integrate-and-fire model [1]_, the differential
equation for the membrane potential is given by

.. math::

\tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\
\text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms}

This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}`
and "threshold" :math:`\vartheta_{rh}`.

The moment when the membrane potential reaches the numerical threshold :math:`V_{th}`
defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to
:math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`,
where :math:`\tau_{\rm ref}` is an absolute refractory time.
If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`,
its exact value does not play any role. The reason is that the upswing of the action
potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in
an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical
convenience. For a formal mathematical analysis of the model, the threshold can be pushed
to infinity.

The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk
and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_.
It is one of the prominent examples of a precise theoretical prediction in computational
neuroscience that was later confirmed by experimental neuroscience.

Two important remarks:

- (i) The right-hand side of the above equation contains a nonlinearity
that can be directly extracted from experimental data [3]_. In this sense the exponential
nonlinearity is not an arbitrary choice but directly supported by experimental evidence.
- (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing
rate for constant input, and the linear response to fluctuations, even in the presence
of input noise [4]_.

**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>> group = bp.dyn.ExpIF(1)
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 10.))
>>> runner.run(300., )
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)


**Model Parameters**

============= ============== ======== ===================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- ---------------------------------------------------
V_rest -65 mV Resting potential.
V_reset -68 mV Reset potential after spike.
V_th -30 mV Threshold potential of spike.
V_T -59.9 mV Threshold potential of generating action potential.
delta_T 3.48 \ Spike slope factor.
R 1 \ Membrane resistance.
tau 10 \ Membrane time constant. Compute by R * C.
tau_ref 1.7 \ Refractory period length.
============= ============== ======== ===================================================

**Model Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V 0 Membrane potential.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
refractory False Flag to mark whether the neuron is in refractory period.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation
mechanisms determine the neuronal response to fluctuating
inputs." Journal of Neuroscience 23.37 (2003): 11628-11640.
.. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
Neuronal dynamics: From single neurons to networks and models
of cognition. Cambridge University Press.
.. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen,
Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves
are reliable predictors of naturalistic pyramidal-neuron voltage
traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666.
.. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear
integrate-and-fire neurons to modulated current-based and
conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919.
.. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire
"""

def __init__(
self,
size: Shape,
V_rest: Parameter = -65.,
V_reset: Parameter = -68.,
V_th: Parameter = -30.,
V_T: Parameter = -59.9,
delta_T: Parameter = 3.48,
R: Parameter = 1.,
tau: Parameter = 10.,
tau_ref: Parameter = 1.7,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
method: str = 'exp_auto',
name: str = None
):
# initialize
super(ExpIF, self).__init__(size=size, name=name)

# parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.V_T = V_T
self.delta_T = delta_T
self.R = R
self.tau = tau
self.tau_ref = tau_ref

# variables
check_initializer(V_initializer, 'V_initializer')
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# integral
self.integral = odeint(method=method, f=self.derivative)

def derivative(self, V, t, I_ext):
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau
return dvdt

def update(self, _t, _dt):
refractory = (_t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, _t, self.input, dt=_dt)
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
self.spike.value = spike
self.input[:] = 0.


class AdExIF(NeuGroup):
r"""Adaptive exponential integrate-and-fire neuron model.

**Model Descriptions**

The **adaptive exponential integrate-and-fire model**, also called AdEx, is a
spiking neuron model with two variables [1]_ [2]_.

.. math::

\begin{aligned}
\tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\
\tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w
\end{aligned}

once the membrane potential reaches the spike threshold,

.. math::

V \rightarrow V_{reset}, \\
w \rightarrow w+b.

The first equation describes the dynamics of the membrane potential and includes
an activation term with an exponential voltage dependence. Voltage is coupled to
a second equation which describes adaptation. Both variables are reset if an action
potential has been triggered. The combination of adaptation and exponential voltage
dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model.

The adaptive exponential integrate-and-fire model is capable of describing known
neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation,
initial bursting, fast spiking, and regular spiking.

**Model Examples**

- `Examples for different firing patterns <https://brainpy-examples.readthedocs.io/en/latest/neurons/Gerstner_2005_AdExIF_model.html>`_

**Model Parameters**

============= ============== ======== ========================================================================================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------
V_rest -65 mV Resting potential.
V_reset -68 mV Reset potential after spike.
V_th -30 mV Threshold potential of spike and reset.
V_T -59.9 mV Threshold potential of generating action potential.
delta_T 3.48 \ Spike slope factor.
a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v`
b 1 \ The increment of :math:`w` produced by a spike.
R 1 \ Membrane resistance.
tau 10 ms Membrane time constant. Compute by R * C.
tau_w 30 ms Time constant of the adaptation current.
============= ============== ======== ========================================================================================================================

**Model Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V 0 Membrane potential.
w 0 Adaptation current.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation
mechanisms determine the neuronal response to fluctuating
inputs." Journal of Neuroscience 23.37 (2003): 11628-11640.
.. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model
"""

def __init__(
self,
size: Shape,
V_rest: Parameter = -65.,
V_reset: Parameter = -68.,
V_th: Parameter = -30.,
V_T: Parameter = -59.9,
delta_T: Parameter = 3.48,
a: Parameter = 1.,
b: Parameter = 1.,
tau: Parameter = 10.,
tau_w: Parameter = 30.,
R: Parameter = 1.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
method: str = 'exp_auto',
name: str = None
):
super(AdExIF, self).__init__(size=size, name=name)

# parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.V_T = V_T
self.delta_T = delta_T
self.a = a
self.b = b
self.tau = tau
self.tau_w = tau_w
self.R = R

# variables
check_initializer(V_initializer, 'V_initializer')
check_initializer(w_initializer, 'w_initializer')
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# functions
self.integral = odeint(method=method, f=self.derivative)

def dV(self, V, t, w, I_ext):
dVdt = (- V + self.V_rest + self.delta_T * bm.exp((V - self.V_T) / self.delta_T) -
self.R * w + self.R * I_ext) / self.tau
return dVdt

def dw(self, w, t, V):
dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w
return dwdt

@property
def derivative(self):
return JointEq([self.dV, self.dw])

def update(self, _t, _dt):
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt)
spike = V >= self.V_th
self.t_last_spike[:] = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.w.value = bm.where(spike, w + self.b, w)
self.spike.value = spike
self.input[:] = 0.


class QuaIF(NeuGroup):
r"""Quadratic Integrate-and-Fire neuron model.

**Model Descriptions**

In contrast to physiologically accurate but computationally expensive
neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only
to produce **action potential-like patterns** and ignores subtleties
like gating variables, which play an important role in generating action
potentials in a real neuron. However, the QIF model is incredibly easy
to implement and compute, and relatively straightforward to study and
understand, thus has found ubiquitous use in computational neuroscience.

.. math::

\tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t)

where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000).

**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>>
>>> group = bp.dyn.QuaIF(1,)
>>>
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 20.))
>>> runner.run(duration=200.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)


**Model Parameters**

============= ============== ======== ========================================================================================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------
V_rest -65 mV Resting potential.
V_reset -68 mV Reset potential after spike.
V_th -30 mV Threshold potential of spike and reset.
V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest.
c .07 \ Coefficient describes membrane potential update. Larger than 0.
R 1 \ Membrane resistance.
tau 10 ms Membrane time constant. Compute by R * C.
tau_ref 0 ms Refractory period length.
============= ============== ======== ========================================================================================================================

**Model Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V 0 Membrane potential.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
refractory False Flag to mark whether the neuron is in refractory period.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg
(2000) Intrinsic dynamics in neuronal networks. I. Theory.
J. Neurophysiology 83, pp. 808–827.
"""

def __init__(
self,
size: Shape,
V_rest: Parameter = -65.,
V_reset: Parameter = -68.,
V_th: Parameter = -30.,
V_c: Parameter = -50.0,
c: Parameter = .07,
R: Parameter = 1.,
tau: Parameter = 10.,
tau_ref: Parameter = 0.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
method: str = 'exp_auto',
name: str = None
):
# initialization
super(QuaIF, self).__init__(size=size, name=name)

# parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.V_c = V_c
self.c = c
self.R = R
self.tau = tau
self.tau_ref = tau_ref

# variables
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# integral
self.integral = odeint(method=method, f=self.derivative)

def derivative(self, V, t, I_ext):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau
return dVdt

def update(self, _t, _dt, **kwargs):
refractory = (_t - self.t_last_spike) <= self.tau_ref
V = self.integral(self.V, _t, self.input, dt=_dt)
V = bm.where(refractory, self.V, V)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.refractory.value = bm.logical_or(refractory, spike)
self.spike.value = spike
self.input[:] = 0.


class AdQuaIF(NeuGroup):
r"""Adaptive quadratic integrate-and-fire neuron model.

**Model Descriptions**

The adaptive quadratic integrate-and-fire neuron model [1]_ is given by:

.. math::

\begin{aligned}
\tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\
\tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w,
\end{aligned}

once the membrane potential reaches the spike threshold,

.. math::

V \rightarrow V_{reset}, \\
w \rightarrow w+b.

**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>> group = bp.dyn.AdQuaIF(1, )
>>> runner = bp.dyn.DSRunner(group, monitors=['V', 'w'], inputs=('input', 30.))
>>> runner.run(300)
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V')
>>> fig.add_subplot(gs[1, 0])
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)

**Model Parameters**

============= ============== ======== =======================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- -------------------------------------------------------
V_rest -65 mV Resting potential.
V_reset -68 mV Reset potential after spike.
V_th -30 mV Threshold potential of spike and reset.
V_c -50 mV Critical voltage for spike initiation. Must be larger
than :math:`V_{rest}`.
a 1 \ The sensitivity of the recovery variable :math:`u` to
the sub-threshold fluctuations of the membrane
potential :math:`v`
b .1 \ The increment of :math:`w` produced by a spike.
c .07 \ Coefficient describes membrane potential update.
Larger than 0.
tau 10 ms Membrane time constant.
tau_w 10 ms Time constant of the adaptation current.
============= ============== ======== =======================================================

**Model Variables**

================== ================= ==========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ----------------------------------------------------------
V 0 Membrane potential.
w 0 Adaptation current.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
t_last_spike -1e7 Last spike time stamp.
================== ================= ==========================================================

**References**

.. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking
neurons?. IEEE transactions on neural networks, 15(5), 1063-1070.
.. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of
nonlinear integrate-and-fire neurons." SIAM Journal on Applied
Mathematics 68, no. 4 (2008): 1045-1079.
"""

def __init__(
self,
size: Shape,
V_rest: Parameter = -65.,
V_reset: Parameter = -68.,
V_th: Parameter = -30.,
V_c: Parameter = -50.0,
a: Parameter = 1.,
b: Parameter = .1,
c: Parameter = .07,
tau: Parameter = 10.,
tau_w: Parameter = 10.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
method: str = 'exp_auto',
name: str = None
):
super(AdQuaIF, self).__init__(size=size, name=name)

# parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.V_c = V_c
self.c = c
self.a = a
self.b = b
self.tau = tau
self.tau_w = tau_w

# variables
check_initializer(V_initializer, 'V_initializer')
check_initializer(w_initializer, 'w_initializer')
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))

# integral
self.integral = odeint(method=method, f=self.derivative)

def dV(self, V, t, w, I_ext):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau
return dVdt

def dw(self, w, t, V):
dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w
return dwdt

@property
def derivative(self):
return JointEq([self.dV, self.dw])

def update(self, _t, _dt):
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt)
spike = self.V_th <= V
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V)
self.w.value = bm.where(spike, w + self.b, w)
self.spike.value = spike
self.input[:] = 0.


class GIF(NeuGroup):
r"""Generalized Integrate-and-Fire model.

**Model Descriptions**

The generalized integrate-and-fire model [1]_ is given by

.. math::

&\frac{d I_j}{d t} = - k_j I_j

&\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau

&\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty})

When :math:`V` meet :math:`V_{th}`, Generalized IF neuron fires:

.. math::

&I_j \leftarrow R_j I_j + A_j

&V \leftarrow V_{reset}

&V_{th} \leftarrow max(V_{th_{reset}}, V_{th})

Note that :math:`I_j` refers to arbitrary number of internal currents.

**Model Examples**

- `Detailed examples to reproduce different firing patterns <https://brainpy-examples.readthedocs.io/en/latest/neurons/Niebur_2009_GIF.html>`_

**Model Parameters**

============= ============== ======== ====================================================================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- --------------------------------------------------------------------
V_rest -70 mV Resting potential.
V_reset -70 mV Reset potential after spike.
V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating.
V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`.
R 20 \ Membrane resistance.
tau 20 ms Membrane time constant. Compute by :math:`R * C`.
a 0 \ Coefficient describes the dependence of
:math:`V_{th}` on membrane potential.
b 0.01 \ Coefficient describes :math:`V_{th}` update.
k1 0.2 \ Constant pf :math:`I1`.
k2 0.02 \ Constant of :math:`I2`.
R1 0 \ Free parameter.
Describes dependence of :math:`I_1` reset value on
:math:`I_1` value before spiking.
R2 1 \ Free parameter.
Describes dependence of :math:`I_2` reset value on
:math:`I_2` value before spiking.
A1 0 \ Free parameter.
A2 0 \ Free parameter.
============= ============== ======== ====================================================================

**Model Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V -70 Membrane potential.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
V_th -50 Spiking threshold potential.
I1 0 Internal current 1.
I2 0 Internal current 2.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear
integrate-and-fire neural model produces diverse spiking
behaviors." Neural computation 21.3 (2009): 704-718.
.. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan
Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized
leaky integrate-and-fire models classify multiple neuron types."
Nature communications 9, no. 1 (2018): 1-15.
"""

def __init__(
self,
size: Shape,
V_rest: Parameter = -70.,
V_reset: Parameter = -70.,
V_th_inf: Parameter = -50.,
V_th_reset: Parameter = -60.,
R: Parameter = 20.,
tau: Parameter = 20.,
a: Parameter = 0.,
b: Parameter = 0.01,
k1: Parameter = 0.2,
k2: Parameter = 0.02,
R1: Parameter = 0.,
R2: Parameter = 1.,
A1: Parameter = 0.,
A2: Parameter = 0.,
V_initializer: Union[Initializer, Callable, Tensor] = OneInit(-70.),
I1_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
I2_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
Vth_initializer: Union[Initializer, Callable, Tensor] = OneInit(-50.),
method: str = 'exp_auto',
name: str = None
):
# initialization
super(GIF, self).__init__(size=size, name=name)

# params
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th_inf = V_th_inf
self.V_th_reset = V_th_reset
self.R = R
self.tau = tau
self.a = a
self.b = b
self.k1 = k1
self.k2 = k2
self.R1 = R1
self.R2 = R2
self.A1 = A1
self.A2 = A2

# variables
check_initializer(V_initializer, 'V_initializer')
check_initializer(I1_initializer, 'I1_initializer')
check_initializer(I2_initializer, 'I2_initializer')
check_initializer(Vth_initializer, 'Vth_initializer')
self.I1 = bm.Variable(init_param(I1_initializer, (self.num,)))
self.I2 = bm.Variable(init_param(I2_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.V_th = bm.Variable(init_param(Vth_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# integral
self.integral = odeint(method=method, f=self.derivative)

def dI1(self, I1, t):
return - self.k1 * I1

def dI2(self, I2, t):
return - self.k2 * I2

def dVth(self, V_th, t, V):
return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf)

def dV(self, V, t, I1, I2, I_ext):
return (- (V - self.V_rest) + self.R * I_ext + self.R * I1 + self.R * I2) / self.tau

@property
def derivative(self):
return JointEq([self.dI1, self.dI2, self.dVth, self.dV])

def update(self, _t, _dt):
I1, I2, V_th, V = self.integral(self.I1, self.I2, self.V_th, self.V, _t, self.input, dt=_dt)
spike = self.V_th <= V
V = bm.where(spike, self.V_reset, V)
I1 = bm.where(spike, self.R1 * I1 + self.A1, I1)
I2 = bm.where(spike, self.R2 * I2 + self.A2, I2)
reset_th = bm.logical_and(V_th < self.V_th_reset, spike)
V_th = bm.where(reset_th, self.V_th_reset, V_th)
self.spike.value = spike
self.I1.value = I1
self.I2.value = I2
self.V_th.value = V_th
self.V.value = V
self.input[:] = 0.


class Izhikevich(NeuGroup):
r"""The Izhikevich neuron model.

@@ -79,8 +924,20 @@ class Izhikevich(NeuGroup):
IEEE transactions on neural networks 15.5 (2004): 1063-1070.
"""

def __init__(self, size, a=0.02, b=0.20, c=-65., d=8., tau_ref=0.,
V_th=30., method='exp_auto', name=None):
def __init__(
self,
size: Shape,
a: Parameter = 0.02,
b: Parameter = 0.20,
c: Parameter = -65.,
d: Parameter = 8.,
tau_ref: Parameter = 0.,
V_th: Parameter = 30.,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
u_initializer: Union[Initializer, Callable, Tensor] = OneInit(),
method: str = 'exp_auto',
name: str = None
):
# initialization
super(Izhikevich, self).__init__(size=size, name=name)

@@ -93,11 +950,13 @@ class Izhikevich(NeuGroup):
self.tau_ref = tau_ref

# variables
self.u = bm.Variable(bm.ones(self.num))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.V = bm.Variable(bm.zeros(self.num))
check_initializer(V_initializer, 'V_initializer')
check_initializer(u_initializer, 'u_initializer')
self.u = bm.Variable(init_param(u_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# functions
@@ -157,7 +1016,7 @@ class HindmarshRose(NeuGroup):
>>> import matplotlib.pyplot as plt
>>>
>>> bp.math.set_dt(dt=0.01)
>>> bp.set_default_odeint('rk4')
>>> bp.ode.set_default_odeint('rk4')
>>>
>>> types = ['quiescence', 'spiking', 'bursting', 'irregular_spiking', 'irregular_bursting']
>>> bs = bp.math.array([1.0, 3.5, 2.5, 2.95, 2.8])
@@ -222,8 +1081,23 @@ class HindmarshRose(NeuGroup):
033128.
"""

def __init__(self, size, a=1., b=3., c=1., d=5., r=0.01, s=4., V_rest=-1.6,
V_th=1.0, method='exp_auto', name=None):
def __init__(
self,
size: Shape,
a: Parameter = 1.,
b: Parameter = 3.,
c: Parameter = 1.,
d: Parameter = 5.,
r: Parameter = 0.01,
s: Parameter = 4.,
V_rest: Parameter = -1.6,
V_th: Parameter = 1.0,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
y_initializer: Union[Initializer, Callable, Tensor] = OneInit(-10.),
z_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
method: str = 'exp_auto',
name: str = None
):
# initialization
super(HindmarshRose, self).__init__(size=size, name=name)

@@ -238,9 +1112,12 @@ class HindmarshRose(NeuGroup):
self.V_rest = V_rest

# variables
self.z = bm.Variable(bm.zeros(self.num))
self.y = bm.Variable(bm.ones(self.num) * -10.)
self.V = bm.Variable(bm.zeros(self.num))
check_initializer(V_initializer, 'V_initializer')
check_initializer(y_initializer, 'y_initializer')
check_initializer(z_initializer, 'z_initializer')
self.z = bm.Variable(init_param(V_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.V = bm.Variable(init_param(z_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
@@ -269,3 +1146,138 @@ class HindmarshRose(NeuGroup):
self.y.value = y
self.z.value = z
self.input[:] = 0.


class FHN(NeuGroup):
r"""FitzHugh-Nagumo neuron model.

**Model Descriptions**

The FitzHugh–Nagumo model (FHN), named after Richard FitzHugh (1922–2007)
who suggested the system in 1961 [1]_ and J. Nagumo et al. who created the
equivalent circuit the following year, describes a prototype of an excitable
system (e.g., a neuron).

The motivation for the FitzHugh-Nagumo model was to isolate conceptually
the essentially mathematical properties of excitation and propagation from
the electrochemical properties of sodium and potassium ion flow. The model
consists of

- a *voltage-like variable* having cubic nonlinearity that allows regenerative
self-excitation via a positive feedback, and
- a *recovery variable* having a linear dynamics that provides a slower negative feedback.

.. math::

\begin{aligned}
{\dot {v}} &=v-{\frac {v^{3}}{3}}-w+RI_{\rm {ext}}, \\
\tau {\dot {w}}&=v+a-bw.
\end{aligned}

The FHN Model is an example of a relaxation oscillator
because, if the external stimulus :math:`I_{\text{ext}}`
exceeds a certain threshold value, the system will exhibit
a characteristic excursion in phase space, before the
variables :math:`v` and :math:`w` relax back to their rest values.
This behaviour is typical for spike generations (a short,
nonlinear elevation of membrane voltage :math:`v`,
diminished over time by a slower, linear recovery variable
:math:`w`) in a neuron after stimulation by an external
input current.

**Model Examples**

.. plot::
:include-source: True

>>> import brainpy as bp
>>> fhn = bp.dyn.FHN(1)
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['V', 'w'])
>>> runner.run(100.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w')
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True)

**Model Parameters**

============= ============== ======== ========================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- ------------------------
a 1 \ Positive constant
b 1 \ Positive constant
tau 10 ms Membrane time constant.
V_th 1.8 mV Threshold potential of spike.
============= ============== ======== ========================

**Model Variables**

================== ================= =========================================================
**Variables name** **Initial Value** **Explanation**
------------------ ----------------- ---------------------------------------------------------
V 0 Membrane potential.
w 0 A recovery variable which represents
the combined effects of sodium channel
de-inactivation and potassium channel
deactivation.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

**References**

.. [1] FitzHugh, Richard. "Impulses and physiological states in theoretical models of nerve membrane." Biophysical journal 1.6 (1961): 445-466.
.. [2] https://en.wikipedia.org/wiki/FitzHugh%E2%80%93Nagumo_model
.. [3] http://www.scholarpedia.org/article/FitzHugh-Nagumo_model

"""

def __init__(
self,
size: Shape,
a: Parameter = 0.7,
b: Parameter = 0.8,
tau: Parameter = 12.5,
Vth: Parameter = 1.8,
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
method: str = 'exp_auto',
name: str = None
):
# initialization
super(FHN, self).__init__(size=size, name=name)

# parameters
self.a = a
self.b = b
self.tau = tau
self.Vth = Vth

# variables
check_initializer(V_initializer, 'V_initializer')
check_initializer(w_initializer, 'w_initializer')
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# integral
self.integral = odeint(method=method, f=self.derivative)

def dV(self, V, t, w, I_ext):
return V - V * V * V / 3 - w + I_ext

def dw(self, w, t, V):
return (V + self.a - self.b * w) / self.tau

@property
def derivative(self):
return JointEq([self.dV, self.dw])

def update(self, _t, _dt):
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt)
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth)
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
self.V.value = V
self.w.value = w
self.input[:] = 0.

brainpy/dyn/runners/ds_runner.py → brainpy/dyn/runners.py View File

@@ -7,6 +7,7 @@ import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap

from brainpy.base.base import TensorCollector
from brainpy import math as bm
from brainpy.dyn import utils
from brainpy.dyn.base import DynamicalSystem
@@ -74,7 +75,6 @@ class DSRunner(Runner):
self.dyn_vars.update({'_i': self._i})
else:
self._i = None
self.dyn_vars.update(self.target.vars().unique())

# run function
self._run_func = self.build_run_function()
@@ -159,29 +159,33 @@ class DSRunner(Runner):
return_with_idx[key] = (data, bm.asarray(idx))

def func(_t, _dt):
res = {k: (v.flatten() if bm.ndim(v) > 1 else v) for k, v in return_without_idx.items()}
res = {k: (v.flatten() if bm.ndim(v) > 1 else v.value)
for k, v in return_without_idx.items()}
res.update({k: (v.flatten()[idx] if bm.ndim(v) > 1 else v[idx])
for k, (v, idx) in return_with_idx.items()})
return res

return func

def _run_one_step(self, t_and_dt):
_t, _dt = t_and_dt[0], t_and_dt[1]
self._input_step(_t=_t, _dt=_dt)
self.target.update(_t=_t, _dt=_dt)
def _run_one_step(self, _t):
self._input_step(_t=_t, _dt=self.dt)
self.target.update(_t=_t, _dt=self.dt)
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())
return self._monitor_step(_t=_t, _dt=_dt)
return self._monitor_step(_t=_t, _dt=self.dt)

def build_run_function(self):
if self.jit:
f_run = bm.make_loop(self._run_one_step, dyn_vars=self.dyn_vars, has_return=True)
dyn_vars = TensorCollector()
dyn_vars.update(self.dyn_vars)
dyn_vars.update(self.target.vars().unique())
f_run = bm.make_loop(self._run_one_step,
dyn_vars=dyn_vars,
has_return=True)
else:
def f_run(t_and_dt):
all_t, all_dt = t_and_dt
def f_run(all_t):
for i in range(all_t.shape[0]):
mon = self._run_one_step((all_t[i], all_dt[i]))
mon = self._run_one_step(all_t[i])
for k, v in mon.items():
self.mon.item_contents[k].append(v)
return None, {}
@@ -212,8 +216,7 @@ class DSRunner(Runner):
start_t = float(self._start_t)
end_t = float(start_t + duration)
# times
times = bm.arange(start_t, end_t, self.dt)
time_steps = bm.ones_like(times) * self.dt
times = np.arange(start_t, end_t, self.dt)
# build monitor
for key in self.mon.item_contents.keys():
self.mon.item_contents[key] = [] # reshape the monitor items
@@ -223,7 +226,7 @@ class DSRunner(Runner):
self._pbar.set_description(f"Running a duration of {round(float(duration), 3)} ({times.size} steps)",
refresh=True)
t0 = time.time()
_, hists = self._run_func([times.value, time_steps.value])
_, hists = self._run_func(times)
running_time = time.time() - t0
if self.progress_bar:
self._pbar.close()
@@ -277,23 +280,24 @@ class ReportRunner(DSRunner):

# Build the update function
if jit:
self._update_step = bm.jit(self.target.update, dyn_vars=self.dyn_vars)
dyn_vars = TensorCollector()
dyn_vars.update(self.dyn_vars)
dyn_vars.update(self.target.vars().unique())
self._update_step = bm.jit(self.target.update, dyn_vars=dyn_vars)
else:
self._update_step = self.target.update

def _run_one_step(self, t_and_dt):
_t, _dt = t_and_dt[0], t_and_dt[1]
self._input_step(_t=_t, _dt=_dt)
self._update_step(_t=_t, _dt=_dt)
def _run_one_step(self, _t):
self._input_step(_t, self.dt)
self._update_step(_t, self.dt)
if self.progress_bar:
self._pbar.update()
return self._monitor_step(_t=_t, _dt=_dt)
return self._monitor_step(_t, self.dt)

def build_run_function(self):
def f_run(t_and_dt):
all_t, all_dt = t_and_dt
def f_run(all_t):
for i in range(all_t.shape[0]):
mon = self._run_one_step((all_t[i], all_dt[i]))
mon = self._run_one_step(all_t[i])
for k, v in mon.items():
self.mon.item_contents[k].append(v)
return None, {}

+ 0
- 3
brainpy/dyn/runners/__init__.py View File

@@ -1,3 +0,0 @@
# -*- coding: utf-8 -*-

from .ds_runner import *

+ 2
- 0
brainpy/dyn/synapses/__init__.py View File

@@ -3,3 +3,5 @@
from .abstract_models import *
from .biological_models import *
from .learning_rules import *
from .delay_coupling import *


+ 24
- 5
brainpy/dyn/synapses/abstract_models.py View File

@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.dyn.base import TwoEndConn, ConstantDelay
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.dyn.base import TwoEndConn, ConstantDelay

__all__ = [
'DeltaSynapse',
@@ -67,8 +68,17 @@ class DeltaSynapse(TwoEndConn):

"""

def __init__(self, pre, post, conn, delay=0., post_has_ref=False, w=1.,
post_key='V', name=None):
def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn,
delay=0.,
post_has_ref=False,
w=1.,
post_key='V',
name=None
):
super(DeltaSynapse, self).__init__(pre=pre, post=post, conn=conn, name=name)
self.check_pre_attrs('spike')
self.check_post_attrs(post_key)
@@ -193,8 +203,17 @@ class ExpCUBA(TwoEndConn):
Cambridge: Cambridge UP, 2011. 172-95. Print.
"""

def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0,
method='exp_auto', name=None):
def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn,
g_max=1.,
delay=0.,
tau=8.0,
method='exp_auto',
name=None
):
super(ExpCUBA, self).__init__(pre=pre, post=post, conn=conn, name=name)
self.check_pre_attrs('spike')
self.check_post_attrs('input', 'V')


+ 206
- 0
brainpy/dyn/synapses/delay_coupling.py View File

@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*-

from typing import Optional, Union, Sequence, Dict, List

from jax import vmap

import brainpy.math as bm
from brainpy.dyn.base import TwoEndConn
from brainpy.initialize import Initializer, ZeroInit
from brainpy.tools.checking import check_sequence
from brainpy.types import Tensor

__all__ = [
'DelayCoupling',
'DiffusiveDelayCoupling',
'AdditiveDelayCoupling',
]


class DelayCoupling(TwoEndConn):
"""
Delay coupling base class.

coupling: str
The way of coupling.
gc: float
The global coupling strength.
signal_speed: float
Signal transmission speed between areas.
sc_mat: optional, tensor
Structural connectivity matrix. Adjacency matrix of coupling strengths,
will be normalized to 1. If not given, then a single node simulation
will be assumed. Default None
fl_mat: optional, tensor
Fiber length matrix. Will be used for computing the
delay matrix together with the signal transmission
speed parameter `signal_speed`. Default None.


"""

"""Global delay variables. Useful when the same target
variable is used in multiple mappings."""
global_delay_vars: Dict[str, bm.LengthDelay] = dict()

def __init__(
self,
pre,
post,
from_to: Union[str, Sequence[str]],
conn_mat: Tensor,
delay_mat: Optional[Tensor] = None,
delay_initializer: Initializer = ZeroInit(),
domain: str = 'local',
name: str = None
):
super(DelayCoupling, self).__init__(pre, post, name=name)

# local delay variables
self.local_delay_vars: Dict[str, bm.LengthDelay] = dict()

# domain
if domain not in ['global', 'local']:
raise ValueError('"domain" must be a string in ["global", "local"]. '
f'Bug we got {domain}.')
self.domain = domain

# pairs of (source, destination)
self.source_target_pairs: Dict[str, List[bm.Variable]] = dict()
source_vars = {}
if isinstance(from_to, str):
from_to = [from_to]
check_sequence(from_to, 'from_to', elem_type=str, allow_none=False)
for pair in from_to:
splits = [v.strip() for v in pair.split('->')]
if len(splits) != 2:
raise ValueError('The (source, target) pair in "from_to" '
'should be defined as "a -> b".')
if not hasattr(self.pre, splits[0]):
raise ValueError(f'"{splits[0]}" is not defined in pre-synaptic group {self.pre.name}')
if not hasattr(self.post, splits[1]):
raise ValueError(f'"{splits[1]}" is not defined in post-synaptic group {self.post.name}')
source = f'{self.pre.name}.{splits[0]}'
target = getattr(self.post, splits[1])
if splits[0] not in self.source_target_pairs:
self.source_target_pairs[source] = [target]
source_vars[source] = getattr(self.pre, splits[0])
if not isinstance(source_vars[source], bm.Variable):
raise ValueError(f'The target variable {source} for delay should '
f'be an instance of brainpy.math.Variable, while '
f'we got {type(source_vars[source])}')
else:
if target in self.source_target_pairs:
raise ValueError(f'{pair} has been defined twice in {from_to}.')
self.source_target_pairs[source].append(target)

# Connection matrix
conn_mat = bm.asarray(conn_mat)
required_shape = (self.post.num, self.pre.num)
if conn_mat.shape != required_shape:
raise ValueError(f'we expect the structural connection matrix has the shape of '
f'(post.num, pre.num), i.e., {required_shape}, '
f'while we got {conn_mat.shape}.')
self.conn_mat = bm.asarray(conn_mat)
bm.fill_diagonal(self.conn_mat, 0)

# Delay matrix
if delay_mat is None:
self.delay_mat = bm.zeros(required_shape, dtype=bm.int_)
else:
if delay_mat.shape != required_shape:
raise ValueError(f'we expect the fiber length matrix has the shape of '
f'(post.num, pre.num), i.e., {required_shape}. '
f'While we got {delay_mat.shape}.')
self.delay_mat = bm.asarray(delay_mat, dtype=bm.int_)

# delay variables
num_delay_step = int(self.delay_mat.max())
for var in self.source_target_pairs.keys():
if domain == 'local':
variable = source_vars[var]
shape = (num_delay_step,) + variable.shape
delay_data = delay_initializer(shape, dtype=variable.dtype)
self.local_delay_vars[var] = bm.LengthDelay(variable, num_delay_step, delay_data)
else:
if var not in self.global_delay_vars:
variable = source_vars[var]
shape = (num_delay_step,) + variable.shape
delay_data = delay_initializer(shape, dtype=variable.dtype)
self.global_delay_vars[var] = bm.LengthDelay(variable, num_delay_step, delay_data)
# save into local delay vars when first seen "var",
# for later update current value!
self.local_delay_vars[var] = self.global_delay_vars[var]
else:
if self.global_delay_vars[var].delay_len < num_delay_step:
variable = source_vars[var]
shape = (num_delay_step,) + variable.shape
delay_data = delay_initializer(shape, dtype=variable.dtype)
self.global_delay_vars[var].init(variable, num_delay_step, delay_data)

self.register_implicit_nodes(self.local_delay_vars)
self.register_implicit_nodes(self.global_delay_vars)

def update(self, _t, _dt):
raise NotImplementedError('Must implement the update() function by users.')


class DiffusiveDelayCoupling(DelayCoupling):
def update(self, _t, _dt):
for source, targets in self.source_target_pairs.items():
# delay variable
if self.domain == 'local':
delay_var: bm.LengthDelay = self.local_delay_vars[source]
elif self.domain == 'global':
delay_var: bm.LengthDelay = self.global_delay_vars[source]
else:
raise ValueError(f'Unknown domain: {self.domain}')

# current data
name, var = source.split('.')
assert name == self.pre.name
variable = getattr(self.pre, var)

# delays
f = vmap(lambda i: delay_var(self.delay_mat[i], bm.arange(self.pre.num))) # (pre.num,)
delays = f(bm.arange(self.post.num).value)
diffusive = delays - bm.expand_dims(variable, axis=1) # (post.num, pre.num)
diffusive = (self.conn_mat * diffusive).sum(axis=1)

# output to target variable
for target in targets:
target.value += diffusive

# update
if source in self.local_delay_vars:
delay_var.update(variable)


class AdditiveDelayCoupling(DelayCoupling):
def update(self, _t, _dt):
for source, targets in self.source_target_pairs.items():
# delay variable
if self.domain == 'local':
delay_var: bm.LengthDelay = self.local_delay_vars[source]
elif self.domain == 'global':
delay_var: bm.LengthDelay = self.global_delay_vars[source]
else:
raise ValueError(f'Unknown domain: {self.domain}')

# current data
name, var = source.split('.')
assert name == self.pre.name
variable = getattr(self.pre, var)

# delay function
f = vmap(lambda i: delay_var(self.delay_mat[i], bm.arange(self.pre.num))) # (pre.num,)
delays = f(bm.arange(self.post.num)) # (post.num, pre.num)
additive = (self.conn_mat * delays).sum(axis=1)

# output to target variable
for target in targets:
target.value += additive

# update
if source in self.local_delay_vars:
delay_var.update(variable)

+ 2
- 1
brainpy/errors.py View File

@@ -101,7 +101,8 @@ class JaxTracerError(MathError):
else:
raise ValueError

msg += 'While there are changed variables which are not wrapped into "dyn_vars". Please check!'
# msg += 'While there are changed variables which are not wrapped into "dyn_vars". Please check!'
msg = 'While there are changed variables which are not wrapped into "dyn_vars". Please check!'

super(JaxTracerError, self).__init__(msg)



+ 1
- 0
brainpy/initialize/__init__.py View File

@@ -6,6 +6,7 @@ You can access them through ``brainpy.init.XXX``.
"""

from .base import *
from .generic import *
from .random_inits import *
from .regular_inits import *
from .decay_inits import *

+ 3
- 3
brainpy/initialize/base.py View File

@@ -13,7 +13,7 @@ class Initializer(abc.ABC):
"""Base Initialization Class."""

@abc.abstractmethod
def __call__(self, shape):
def __call__(self, shape, dtype=None):
raise NotImplementedError


@@ -21,7 +21,7 @@ class InterLayerInitializer(Initializer):
"""The superclass of Initializers that initialize the weights between two layers."""

@abc.abstractmethod
def __call__(self, shape):
def __call__(self, shape, dtype=None):
raise NotImplementedError


@@ -29,5 +29,5 @@ class IntraLayerInitializer(Initializer):
"""The superclass of Initializers that initialize the weights within a layer."""

@abc.abstractmethod
def __call__(self, shape):
def __call__(self, shape, dtype=None):
raise NotImplementedError

+ 46
- 0
brainpy/initialize/generic.py View File

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable

import jax.numpy as jnp
import numpy as onp

import brainpy.math as bm
from brainpy.tools.others import to_size
from brainpy.types import Shape
from .base import Initializer

__all__ = [
'init_param',
]


def init_param(param: Union[Callable, Initializer, bm.ndarray, jnp.ndarray],
size: Shape):
"""Initialize parameters.

Parameters
----------
param: callable, Initializer, bm.ndarray, jnp.ndarray
The initialization of the parameter.
- If it is None, the created parameter will be None.
- If it is a callable function :math:`f`, the ``f(size)`` will be returned.
- If it is an instance of :py:class:`brainpy.init.Initializer``, the ``f(size)`` will be returned.
- If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``.
size: int, sequence of int
The shape of the parameter.
"""
size = to_size(size)
if param is None:
return None
elif callable(param):
param = param(size)
elif isinstance(param, (onp.ndarray, jnp.ndarray)):
param = bm.asarray(param)
elif isinstance(param, (bm.JaxArray,)):
param = param
else:
raise ValueError(f'Unknown param type {type(param)}: {param}')
assert param.shape == size, f'"param.shape" is not the required size {size}'
return param


+ 7
- 8
brainpy/initialize/random_inits.py View File

@@ -40,7 +40,7 @@ class Normal(InterLayerInitializer):
def __init__(self, scale=1., seed=None):
super(Normal, self).__init__()
self.scale = scale
self.rng = bm.random.RandomState(seed=seed)
self.rng = np.random.RandomState(seed=seed)

def __call__(self, shape, dtype=None):
shape = [tools.size2num(d) for d in shape]
@@ -64,7 +64,7 @@ class Uniform(InterLayerInitializer):
super(Uniform, self).__init__()
self.min_val = min_val
self.max_val = max_val
self.rng = bm.random.RandomState(seed=seed)
self.rng = np.random.RandomState(seed=seed)

def __call__(self, shape, dtype=None):
shape = [tools.size2num(d) for d in shape]
@@ -79,7 +79,7 @@ class VarianceScaling(InterLayerInitializer):
self.in_axis = in_axis
self.out_axis = out_axis
self.distribution = distribution
self.rng = bm.random.RandomState(seed=seed)
self.rng = np.random.RandomState(seed=seed)

def __call__(self, shape, dtype=None):
shape = [tools.size2num(d) for d in shape]
@@ -94,18 +94,17 @@ class VarianceScaling(InterLayerInitializer):
raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
variance = bm.array(self.scale / denominator, dtype=dtype)
if self.distribution == "truncated_normal":
from scipy.stats import truncnorm
# constant is stddev of standard normal truncated to (-2, 2)
stddev = bm.sqrt(variance) / bm.array(.87962566103423978, dtype)
res = self.rng.truncated_normal(-2, 2, shape) * stddev
return bm.asarray(res, dtype=dtype)
res = truncnorm(-2, 2).rvs(shape) * stddev
elif self.distribution == "normal":
res = self.rng.normal(size=shape) * bm.sqrt(variance)
return bm.asarray(res, dtype=dtype)
elif self.distribution == "uniform":
res = self.rng.uniform(low=-1, high=1, size=shape) * bm.sqrt(3 * variance)
return bm.asarray(res, dtype=dtype)
else:
raise ValueError("invalid distribution for variance scaling initializer")
return bm.asarray(res, dtype=dtype)


class KaimingUniform(VarianceScaling):
@@ -180,7 +179,7 @@ class Orthogonal(InterLayerInitializer):
super(Orthogonal, self).__init__()
self.scale = scale
self.axis = axis
self.rng = bm.random.RandomState(seed=seed)
self.rng = np.random.RandomState(seed=seed)

def __call__(self, shape, dtype=None):
shape = [tools.size2num(d) for d in shape]


+ 1
- 199
brainpy/inputs/__init__.py View File

@@ -6,203 +6,5 @@ This module provides various methods to form current inputs.
You can access them through ``brainpy.inputs.XXX``.
"""

import numpy as np

from brainpy import math as bm

__all__ = [
'section_input',
'constant_input', 'constant_current',
'spike_input', 'spike_current',
'ramp_input', 'ramp_current',
]


def section_input(values, durations, dt=None, return_length=False):
"""Format an input current with different sections.

For example:

If you want to get an input where the size is 0 bwteen 0-100 ms,
and the size is 1. between 100-200 ms.

>>> section_input(values=[0, 1],
>>> durations=[100, 100])

Parameters
----------
values : list, np.ndarray
The current values for each period duration.
durations : list, np.ndarray
The duration for each period.
dt : float
Default is None.
return_length : bool
Return the final duration length.

Returns
-------
current_and_duration : tuple
(The formatted current, total duration)
"""
assert len(durations) == len(values), f'"values" and "durations" must be the same length, while ' \
f'we got {len(values)} != {len(durations)}.'

dt = bm.get_dt() if dt is None else dt

# get input current shape, and duration
I_duration = sum(durations)
I_shape = ()
for val in values:
shape = bm.shape(val)
if len(shape) > len(I_shape):
I_shape = shape

# get the current
start = 0
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape, dtype=bm.float_)
for c_size, duration in zip(values, durations):
length = int(duration / dt)
I_current[start: start + length] = c_size
start += length

if return_length:
return I_current, I_duration
else:
return I_current


def constant_input(I_and_duration, dt=None):
"""Format constant input in durations.

For example:

If you want to get an input where the size is 0 bwteen 0-100 ms,
and the size is 1. between 100-200 ms.

>>> import brainpy.math as bm
>>> constant_input([(0, 100), (1, 100)])
>>> constant_input([(bm.zeros(100), 100), (bm.random.rand(100), 100)])

Parameters
----------
I_and_duration : list
This parameter receives the current size and the current
duration pairs, like `[(Isize1, duration1), (Isize2, duration2)]`.
dt : float
Default is None.

Returns
-------
current_and_duration : tuple
(The formatted current, total duration)
"""
dt = bm.get_dt() if dt is None else dt

# get input current dimension, shape, and duration
I_duration = 0.
I_shape = ()
for I in I_and_duration:
I_duration += I[1]
shape = bm.shape(I[0])
if len(shape) > len(I_shape):
I_shape = shape

# get the current
start = 0
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape, dtype=bm.float_)
for c_size, duration in I_and_duration:
length = int(duration / dt)
I_current[start: start + length] = c_size
start += length
return I_current, I_duration


constant_current = constant_input


def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None):
"""Format current input like a series of short-time spikes.

For example:

If you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms,
and each spike lasts 1 ms and the spike current is 0.5, then you can use the
following funtions:

>>> spike_input(sp_times=[10, 20, 30, 200, 300],
>>> sp_lens=1., # can be a list to specify the spike length at each point
>>> sp_sizes=0.5, # can be a list to specify the current size at each point
>>> duration=400.)

Parameters
----------
sp_times : list, tuple
The spike time-points. Must be an iterable object.
sp_lens : int, float, list, tuple
The length of each point-current, mimicking the spike durations.
sp_sizes : int, float, list, tuple
The current sizes.
duration : int, float
The total current duration.
dt : float
The default is None.

Returns
-------
current : bm.ndarray
The formatted input current.
"""
dt = bm.get_dt() if dt is None else dt
assert isinstance(sp_times, (list, tuple))
if isinstance(sp_lens, (float, int)):
sp_lens = [sp_lens] * len(sp_times)
if isinstance(sp_sizes, (float, int)):
sp_sizes = [sp_sizes] * len(sp_times)

current = bm.zeros(int(np.ceil(duration / dt)), dtype=bm.float_)
for time, dur, size in zip(sp_times, sp_lens, sp_sizes):
pp = int(time / dt)
p_len = int(dur / dt)
current[pp: pp + p_len] = size
return current


spike_current = spike_input


def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
"""Get the gradually changed input current.

Parameters
----------
c_start : float
The minimum (or maximum) current size.
c_end : float
The maximum (or minimum) current size.
duration : int, float
The total duration.
t_start : float
The ramped current start time-point.
t_end : float
The ramped current end time-point. Default is the None.
dt : float, int, optional
The numerical precision.

Returns
-------
current : bm.ndarray
The formatted current
"""
dt = bm.get_dt() if dt is None else dt
t_end = duration if t_end is None else t_end

current = bm.zeros(int(np.ceil(duration / dt)), dtype=bm.float_)
p1 = int(np.ceil(t_start / dt))
p2 = int(np.ceil(t_end / dt))
current[p1: p2] = bm.array(bm.linspace(c_start, c_end, p2 - p1), dtype=bm.float_)
return current


ramp_current = ramp_input
from .currents import *


+ 386
- 0
brainpy/inputs/currents.py View File

@@ -0,0 +1,386 @@
# -*- coding: utf-8 -*-

import numpy as np

from brainpy import math as bm
from brainpy.tools.checking import check_float, check_integer

__all__ = [
'section_input',
'constant_input', 'constant_current',
'spike_input', 'spike_current',
'ramp_input', 'ramp_current',
'wiener_process',
'ou_process',
'sinusoidal_input',
'square_input',
]


def section_input(values, durations, dt=None, return_length=False):
"""Format an input current with different sections.

For example:

If you want to get an input where the size is 0 bwteen 0-100 ms,
and the size is 1. between 100-200 ms.

>>> section_input(values=[0, 1],
>>> durations=[100, 100])

Parameters
----------
values : list, np.ndarray
The current values for each period duration.
durations : list, np.ndarray
The duration for each period.
dt : float
Default is None.
return_length : bool
Return the final duration length.

Returns
-------
current_and_duration : tuple
(The formatted current, total duration)
"""
assert len(durations) == len(values), f'"values" and "durations" must be the same length, while ' \
f'we got {len(values)} != {len(durations)}.'

dt = bm.get_dt() if dt is None else dt

# get input current shape, and duration
I_duration = sum(durations)
I_shape = ()
for val in values:
shape = bm.shape(val)
if len(shape) > len(I_shape):
I_shape = shape

# get the current
start = 0
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape, dtype=bm.float_)
for c_size, duration in zip(values, durations):
length = int(duration / dt)
I_current[start: start + length] = c_size
start += length

if return_length:
return I_current, I_duration
else:
return I_current


def constant_input(I_and_duration, dt=None):
"""Format constant input in durations.

For example:

If you want to get an input where the size is 0 bwteen 0-100 ms,
and the size is 1. between 100-200 ms.

>>> import brainpy.math as bm
>>> constant_input([(0, 100), (1, 100)])
>>> constant_input([(bm.zeros(100), 100), (bm.random.rand(100), 100)])

Parameters
----------
I_and_duration : list
This parameter receives the current size and the current
duration pairs, like `[(Isize1, duration1), (Isize2, duration2)]`.
dt : float
Default is None.

Returns
-------
current_and_duration : tuple
(The formatted current, total duration)
"""
dt = bm.get_dt() if dt is None else dt

# get input current dimension, shape, and duration
I_duration = 0.
I_shape = ()
for I in I_and_duration:
I_duration += I[1]
shape = bm.shape(I[0])
if len(shape) > len(I_shape):
I_shape = shape

# get the current
start = 0
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape, dtype=bm.float_)
for c_size, duration in I_and_duration:
length = int(duration / dt)
I_current[start: start + length] = c_size
start += length
return I_current, I_duration


constant_current = constant_input


def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None):
"""Format current input like a series of short-time spikes.

For example:

If you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms,
and each spike lasts 1 ms and the spike current is 0.5, then you can use the
following funtions:

>>> spike_input(sp_times=[10, 20, 30, 200, 300],
>>> sp_lens=1., # can be a list to specify the spike length at each point
>>> sp_sizes=0.5, # can be a list to specify the current size at each point
>>> duration=400.)

Parameters
----------
sp_times : list, tuple
The spike time-points. Must be an iterable object.
sp_lens : int, float, list, tuple
The length of each point-current, mimicking the spike durations.
sp_sizes : int, float, list, tuple
The current sizes.
duration : int, float
The total current duration.
dt : float
The default is None.

Returns
-------
current : bm.ndarray
The formatted input current.
"""
dt = bm.get_dt() if dt is None else dt
assert isinstance(sp_times, (list, tuple))
if isinstance(sp_lens, (float, int)):
sp_lens = [sp_lens] * len(sp_times)
if isinstance(sp_sizes, (float, int)):
sp_sizes = [sp_sizes] * len(sp_times)

current = bm.zeros(int(np.ceil(duration / dt)), dtype=bm.float_)
for time, dur, size in zip(sp_times, sp_lens, sp_sizes):
pp = int(time / dt)
p_len = int(dur / dt)
current[pp: pp + p_len] = size
return current


spike_current = spike_input


def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
"""Get the gradually changed input current.

Parameters
----------
c_start : float
The minimum (or maximum) current size.
c_end : float
The maximum (or minimum) current size.
duration : int, float
The total duration.
t_start : float
The ramped current start time-point.
t_end : float
The ramped current end time-point. Default is the None.
dt : float, int, optional
The numerical precision.

Returns
-------
current : bm.ndarray
The formatted current
"""
dt = bm.get_dt() if dt is None else dt
t_end = duration if t_end is None else t_end

current = bm.zeros(int(np.ceil(duration / dt)), dtype=bm.float_)
p1 = int(np.ceil(t_start / dt))
p2 = int(np.ceil(t_end / dt))
current[p1: p2] = bm.array(bm.linspace(c_start, c_end, p2 - p1), dtype=bm.float_)
return current


ramp_current = ramp_input


def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None):
"""Stimulus sampled from a Wiener process, i.e.
drawn from standard normal distribution N(0, sqrt(dt)).

Parameters
----------
duration: float
The input duration.
dt: float
The numerical precision.
n: int
The variable number.
t_start: float
The start time.
t_end: float
The end time.
seed: int
The noise seed.
"""
dt = bm.get_dt() if dt is None else dt
check_float(dt, 'dt', allow_none=False, min_bound=0.)
check_integer(n, 'n', allow_none=False, min_bound=0)
rng = bm.random.RandomState(seed)
t_end = duration if t_end is None else t_end
i_start = int(t_start / dt)
i_end = int(t_end / dt)
noises = rng.standard_normal((i_end - i_start, n)) * bm.sqrt(dt)
currents = bm.zeros((int(duration / dt), n))
currents[i_start: i_end] = noises
return currents


def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None, seed=None):
r"""Ornstein–Uhlenbeck input.

.. math::

dX = (mu - X)/\tau * dt + \sigma*dW

Parameters
----------
mean: float
Drift of the OU process.
sigma: float
Standard deviation of the Wiener process, i.e. strength of the noise.
tau: float
Timescale of the OU process, in ms.
duration: float
The input duration.
dt: float
The numerical precision.
n: int
The variable number.
t_start: float
The start time.
t_end: float
The end time.

"""
dt = bm.get_dt() if dt is None else dt
dt_sqrt = bm.sqrt(dt)
check_float(dt, 'dt', allow_none=False, min_bound=0.)
check_integer(n, 'n', allow_none=False, min_bound=0)
rng = bm.random.RandomState(seed)
x = bm.Variable(bm.ones(n) * mean)

def _f(t):
x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.standard_normal(n)

f = bm.make_loop(_f, dyn_vars=[x, rng], out_vars=x)
noises = f(bm.arange(t_start, t_end, dt))

t_end = duration if t_end is None else t_end
i_start = int(t_start / dt)
i_end = int(t_end / dt)
currents = bm.zeros((int(duration / dt), n))
currents[i_start: i_end] = noises
return currents


def sinusoidal_input(amplitude, frequency, duration, dt=None, t_start=0., t_end=None, dc_bias=False):
"""Sinusoidal input.

Parameters
----------
amplitude: float
Amplitude of the sinusoid.
frequency: float
Frequency of the sinus oscillation, in Hz
duration: float
The input duration.
t_start: float
The start time.
t_end: float
The end time.
dt: float
The numerical precision.
dc_bias: bool
Whether the sinusoid oscillates around 0 (False), or
has a positive DC bias, thus non-negative (True).
"""
dt = bm.get_dt() if dt is None else dt
check_float(dt, 'dt', allow_none=False, min_bound=0.)
if t_end is None:
t_end = duration
times = bm.arange(0, t_end-t_start, dt)
start_i = int(t_start/dt)
end_i = int(t_end/dt)
sin_inputs = amplitude * bm.sin(2 * bm.pi * times * (frequency / 1000.0))
if dc_bias:
sin_inputs += amplitude
currents = bm.zeros(int(duration / dt))
currents[start_i:end_i] = sin_inputs
return currents


def _square(t, duty=0.5):
t, w = np.asarray(t), np.asarray(duty)
w = np.asarray(w + (t - t))
t = np.asarray(t + (w - w))
if t.dtype.char in ['fFdD']:
ytype = t.dtype.char
else:
ytype = 'd'

y = np.zeros(t.shape, ytype)

# width must be between 0 and 1 inclusive
mask1 = (w > 1) | (w < 0)
np.place(y, mask1, np.nan)

# on the interval 0 to duty*2*pi function is 1
tmod = np.mod(t, 2 * np.pi)
mask2 = (1 - mask1) & (tmod < w * 2 * np.pi)
np.place(y, mask2, 1)

# on the interval duty*2*pi to 2*pi function is
# (pi*(w+1)-tmod) / (pi*(1-w))
mask3 = (1 - mask1) & (1 - mask2)
np.place(y, mask3, -1)
return y


def square_input(amplitude, frequency, duration, dt=None, dc_bias=False, t_start=None, t_end=None):
"""Oscillatory square input.

Parameters
----------
amplitude: float
Amplitude of the square oscillation.
frequency: float
Frequency of the square oscillation, in Hz.
duration: float
The input duration.
t_start: float
The start time.
t_end: float
The end time.
dt: float
The numerical precision.
dc_bias: bool
Whether the sinusoid oscillates around 0 (False), or
has a positive DC bias, thus non-negative (True).
"""
dt = bm.get_dt() if dt is None else dt
check_float(dt, 'dt', allow_none=False, min_bound=0.)
if t_end is None:
t_end = duration
times = bm.arange(0, t_end - t_start, dt)
currents = bm.zeros(int(duration / dt))
start_i = int(t_start/dt)
end_i = int(t_end/dt)
sin_inputs = amplitude * _square(2 * bm.pi * times * (frequency / 1000.0))
if dc_bias:
sin_inputs += amplitude
currents[start_i:end_i] = sin_inputs
return currents


+ 10
- 1
brainpy/integrators/__init__.py View File

@@ -7,6 +7,7 @@ including:
- ordinary differential equations (ODEs)
- stochastic differential equations (SDEs)
- delay differential equations (DDEs)
- fractional differential equations (FDEs)

Details please see the following.
"""
@@ -41,6 +42,14 @@ from .dde.generic import (ddeint,
set_default_ddeint,
register_dde_integrator)

# others
# FDE tools
from . import fde
from .fde.base import FDEIntegrator
from .fde.generic import (fdeint,
get_default_fdeint,
set_default_fdeint,
register_fde_integrator)


# PDE tools
from . import pde

+ 8
- 2
brainpy/integrators/base.py View File

@@ -34,12 +34,13 @@ class Integrator(AbstractIntegrator):
self._dt = dt
check_float(dt, 'dt', allow_none=False, allow_int=True)
self._variables = variables # variables
self._parameters = parameters # parameters
self._arguments = list(arguments) + [f'{DT}={self.dt}'] # arguments
self._parameters = parameters # parameters
self._arguments = list(arguments) + [f'{DT}={self.dt}'] # arguments
self._integral = None # integral function

@property
def dt(self):
"""The numerical integration precision."""
return self._dt

@dt.setter
@@ -48,6 +49,7 @@ class Integrator(AbstractIntegrator):

@property
def variables(self):
"""The variables defined in the differential equation."""
return self._variables

@variables.setter
@@ -56,6 +58,7 @@ class Integrator(AbstractIntegrator):

@property
def parameters(self):
"""The parameters defined in the differential equation."""
return self._parameters

@parameters.setter
@@ -64,6 +67,7 @@ class Integrator(AbstractIntegrator):

@property
def arguments(self):
"""All arguments when calling the numer integrator of the differential equation."""
return self._arguments

@arguments.setter
@@ -72,6 +76,7 @@ class Integrator(AbstractIntegrator):

@property
def integral(self):
"""The integral function."""
return self._integral

@integral.setter
@@ -79,6 +84,7 @@ class Integrator(AbstractIntegrator):
self.set_integral(f)

def set_integral(self, f):
"""Set the integral function."""
if not callable(f):
raise ValueError(f'integral function must be a callable function, '
f'but we got {type(f)}: {f}')


+ 17
- 5
brainpy/integrators/dde/base.py View File

@@ -25,7 +25,7 @@ class DDEIntegrator(Integrator):
dt: Union[float, int] = None,
name: str = None,
show_code: bool = False,
state_delays: Dict[str, bm.FixedLenDelay] = None,
state_delays: Dict[str, bm.TimeDelay] = None,
neutral_delays: Dict[str, bm.NeutralDelay] = None,
):
dt = bm.get_dt() if dt is None else dt
@@ -59,7 +59,9 @@ class DDEIntegrator(Integrator):
# delays
self._state_delays = dict()
if state_delays is not None:
check_dict_data(state_delays, key_type=str, val_type=bm.FixedLenDelay)
check_dict_data(state_delays,
key_type=str,
val_type=(bm.TimeDelay, bm.LengthDelay))
for key, delay in state_delays.items():
if key not in self.variables:
raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}')
@@ -67,7 +69,9 @@ class DDEIntegrator(Integrator):
self.register_implicit_nodes(self._state_delays)
self._neutral_delays = dict()
if neutral_delays is not None:
check_dict_data(neutral_delays, key_type=str, val_type=bm.NeutralDelay)
check_dict_data(neutral_delays,
key_type=str,
val_type=bm.NeutralDelay)
for key, delay in neutral_delays.items():
if key not in self.variables:
raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}')
@@ -111,11 +115,19 @@ class DDEIntegrator(Integrator):
else:
new_dvars = {k: new_dvars[i] for i, k in enumerate(self.variables)}
for key, delay in self.neutral_delays.items():
delay.update(kwargs['t'] + dt, new_dvars[key])
if isinstance(delay, bm.LengthDelay):
delay.update(new_dvars[key])
elif isinstance(delay, bm.TimeDelay):
delay.update(kwargs['t'] + dt, new_dvars[key])
raise ValueError('Unknown delay variable.')

# update state delay variables
for key, delay in self.state_delays.items():
delay.update(kwargs['t'] + dt, dict_vars[key])
if isinstance(delay, bm.LengthDelay):
delay.update(dict_vars[key])
elif isinstance(delay, bm.TimeDelay):
delay.update(kwargs['t'] + dt, dict_vars[key])
raise ValueError('Unknown delay variable.')

return new_vars



+ 37
- 2
brainpy/integrators/dde/explicit_rk.py View File

@@ -4,6 +4,7 @@ from brainpy.integrators.constants import F, DT
from brainpy.integrators.dde.base import DDEIntegrator
from brainpy.integrators.ode import common
from brainpy.integrators.utils import compile_code, check_kws
from brainpy.integrators.dde.generic import register_dde_integrator

__all__ = [
'ExplicitRKIntegrator',
@@ -47,8 +48,6 @@ class ExplicitRKIntegrator(DDEIntegrator):
def integral(*vars, **kwargs):
pass



self.build()

def build(self):
@@ -72,24 +71,36 @@ class Euler(ExplicitRKIntegrator):
C = [0]


register_dde_integrator('euler', Euler)


class MidPoint(ExplicitRKIntegrator):
A = [(), (0.5,)]
B = [0, 1]
C = [0, 0.5]


register_dde_integrator('midpoint', MidPoint)


class Heun2(ExplicitRKIntegrator):
A = [(), (1,)]
B = [0.5, 0.5]
C = [0, 1]


register_dde_integrator('heun2', Heun2)


class Ralston2(ExplicitRKIntegrator):
A = [(), ('2/3',)]
B = [0.25, 0.75]
C = [0, '2/3']


register_dde_integrator('ralston2', Ralston2)


class RK2(ExplicitRKIntegrator):
def __init__(self, f, beta=2 / 3, var_type=None, dt=None, name=None, show_code=False):
self.A = [(), (beta,)]
@@ -98,43 +109,67 @@ class RK2(ExplicitRKIntegrator):
super(RK2, self).__init__(f=f, var_type=var_type, dt=dt, name=name, show_code=show_code)


register_dde_integrator('rk2', RK2)


class RK3(ExplicitRKIntegrator):
A = [(), (0.5,), (-1, 2)]
B = ['1/6', '2/3', '1/6']
C = [0, 0.5, 1]


register_dde_integrator('rk3', RK3)


class Heun3(ExplicitRKIntegrator):
A = [(), ('1/3',), (0, '2/3')]
B = [0.25, 0, 0.75]
C = [0, '1/3', '2/3']


register_dde_integrator('heun3', Heun3)


class Ralston3(ExplicitRKIntegrator):
A = [(), (0.5,), (0, 0.75)]
B = ['2/9', '1/3', '4/9']
C = [0, 0.5, 0.75]


register_dde_integrator('ralston3', Ralston3)


class SSPRK3(ExplicitRKIntegrator):
A = [(), (1,), (0.25, 0.25)]
B = ['1/6', '1/6', '2/3']
C = [0, 1, 0.5]


register_dde_integrator('ssprk3', SSPRK3)


class RK4(ExplicitRKIntegrator):
A = [(), (0.5,), (0., 0.5), (0., 0., 1)]
B = ['1/6', '1/3', '1/3', '1/6']
C = [0, 0.5, 0.5, 1]


register_dde_integrator('rk4', RK4)


class Ralston4(ExplicitRKIntegrator):
A = [(), (.4,), (.29697761, .15875964), (.21810040, -3.05096516, 3.83286476)]
B = [.17476028, -.55148066, 1.20553560, .17118478]
C = [0, .4, .45573725, 1]


register_dde_integrator('ralston4', Ralston4)


class RK4Rule38(ExplicitRKIntegrator):
A = [(), ('1/3',), ('-1/3', '1'), (1, -1, 1)]
B = [0.125, 0.375, 0.375, 0.125]
C = [0, '1/3', '2/3', 1]


register_dde_integrator('rk4_38rule', RK4Rule38)

+ 1
- 15
brainpy/integrators/dde/generic.py View File

@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-

from .base import DDEIntegrator
from .explicit_rk import *

__all__ = [
'ddeint',
@@ -12,19 +11,6 @@ __all__ = [
]

name2method = {
# explicit RK
'euler': Euler, 'Euler': Euler,
'midpoint': MidPoint, 'MidPoint': MidPoint,
'heun2': Heun2, 'Heun2': Heun2,
'ralston2': Ralston2, 'Ralston2': Ralston2,
'rk2': RK2, 'RK2': RK2,
'rk3': RK3, 'RK3': RK3,
'heun3': Heun3, 'Heun3': Heun3,
'ralston3': Ralston3, 'Ralston3': Ralston3,
'ssprk3': SSPRK3, 'SSPRK3': SSPRK3,
'rk4': RK4, 'RK4': RK4,
'ralston4': Ralston4, 'Ralston4': Ralston4,
'rk4_38rule': RK4Rule38, 'RK4Rule38': RK4Rule38,
}


@@ -132,7 +118,7 @@ def register_dde_integrator(name, integrator):
"""
if name in name2method:
raise ValueError(f'"{name}" has been registered in DDE integrators.')
if DDEIntegrator not in integrator.__bases__:
if not issubclass(integrator, DDEIntegrator):
raise ValueError(f'"integrator" must be an instance of {DDEIntegrator.__name__}')
name2method[name] = integrator



+ 401
- 0
brainpy/integrators/fde/Caputo.py View File

@@ -0,0 +1,401 @@
# -*- coding: utf-8 -*-

"""
This module provides numerical methods for integrating Caputo fractional derivative equations.

"""

import jax.numpy as jnp
from jax.experimental.host_callback import id_tap

from brainpy import check
import brainpy.math as bm
from brainpy.errors import UnsupportedError
from brainpy.integrators.constants import DT
from brainpy.integrators.utils import check_inits, format_args
from brainpy.tools.checking import check_integer
from .base import FDEIntegrator
from .generic import register_fde_integrator

__all__ = [
'CaputoEuler',
'CaputoL1Schema',
]


class CaputoEuler(FDEIntegrator):
r"""One-step Euler method for Caputo fractional differential equations.

Given a fractional initial value problem,

.. math::

D_{*}^{\alpha} y(t)=f(t, y(t)), \quad y^{(k)}(0)=y_{0}^{(k)}, \quad k=0,1, \ldots,\lceil\alpha\rceil-1

where the :math:`y_0^{(k)}` ay be arbitrary real numbers and where :math:`\alpha>0`.
:math:`D_{*}^{\alpha}` denotes the differential operator in the sense of Caputo, defined
by

.. math::

D_{*}^{\alpha} z(t)=J^{n-\alpha} D^{n} z(t)

where :math:`n:=\lceil\alpha\rceil` is the smallest integer :math:`\geqslant \alpha`,
Here :math:`D^n` is the usual differential operator of (integer) order :math:`n`,
and for :math:`\mu > 0`, :math:`J^{\mu}` is the Riemann–Liouville integral operator
of order :math:`\mu`, defined by

.. math::

J^{\mu} z(t)=\frac{1}{\Gamma(\mu)} \int_{0}^{t}(t-u)^{\mu-1} z(u) \mathrm{d} u

The one-step Euler method for fractional differential equation is defined as

.. math::

y_{k+1} = y_0 + \frac{1}{\Gamma(\alpha)} \sum_{j=0}^{k} b_{j, k+1} f\left(t_{j}, y_{j}\right).

where

.. math::

b_{j, k+1}=\frac{h^{\alpha}}{\alpha}\left((k+1-j)^{\alpha}-(k-j)^{\alpha}\right).


Examples
--------

>>> import brainpy as bp
>>>
>>> a, b, c = 10, 28, 8 / 3
>>> def lorenz(x, y, z, t):
>>> dx = a * (y - x)
>>> dy = x * (b - z) - y
>>> dz = x * y - c * z
>>> return dx, dy, dz
>>>
>>> duration = 30.
>>> dt = 0.005
>>> inits = [1., 0., 1.]
>>> f = bp.fde.CaputoEuler(lorenz, alpha=0.97, num_step=int(duration / dt), inits=inits)
>>> runner = bp.integrators.IntegratorRunner(f, monitors=list('xyz'), dt=dt, inits=inits)
>>> runner.run(duration)
>>>
>>> import matplotlib.pyplot as plt
>>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten())
>>> plt.show()


Parameters
----------
f : callable
The derivative function.
alpha: int, float, jnp.ndarray, bm.ndarray, sequence
The fractional-order of the derivative function. Should be in the range of ``(0., 1.)``.
num_step: int
The total time step of the simulation.
inits: sequence
A sequence of the initial values for variables.
dt: float, int
The numerical precision.
name: str
The integrator name.

References
----------
.. [1] Li, Changpin, and Fanhai Zeng. "The finite difference methods for fractional
ordinary differential equations." Numerical Functional Analysis and
Optimization 34.2 (2013): 149-179.
.. [2] Diethelm, Kai, Neville J. Ford, and Alan D. Freed. "Detailed error analysis
for a fractional Adams method." Numerical algorithms 36.1 (2004): 31-52.
"""

def __init__(self, f, alpha, num_step, inits, dt=None, name=None):
super(CaputoEuler, self).__init__(f=f, alpha=alpha, dt=dt, name=name)

# fractional order
if not jnp.all(jnp.logical_and(self.alpha < 1, self.alpha > 0)):
raise UnsupportedError(f'Only support the fractional order in (0, 1), '
f'but we got {self.alpha}.')

# memory length
check_integer(num_step, 'num_step', min_bound=1, allow_none=False)
self.num_step = num_step

# initial values
self.inits = check_inits(inits, self.variables)

# coefficients
from scipy.special import rgamma
rgamma_alpha = bm.asarray(rgamma(bm.as_numpy(self.alpha)))
ranges = bm.asarray([bm.arange(num_step + 1) for _ in self.variables]).T
coef = rgamma_alpha * bm.diff(bm.power(ranges, self.alpha), axis=0)
self.coef = bm.flip(coef, axis=0)

# variable states
self.f_states = {v: bm.Variable(bm.zeros((num_step,) + self.inits[v].shape))
for v in self.variables}
self.register_implicit_vars(self.f_states)
self.idx = bm.Variable(bm.asarray([1], dtype=bm.int32))

self.set_integral(self._integral_func)

def _check_step(self, args, transform):
dt, t = args
if self.num_step * dt < t:
raise ValueError(f'The maximum number of step is {self.num_step}, '
f'however, the current time {t} require a time '
f'step number {t / dt}.')

def _integral_func(self, *args, **kwargs):
# format arguments
all_args = format_args(args, kwargs, self.arguments)
dt = all_args.pop(DT, self.dt)
if check.is_checking():
id_tap(self._check_step, (dt, all_args['t']))

# derivative values
devs = self.f(**all_args)
if len(self.variables) == 1:
if not isinstance(devs, (bm.ndarray, jnp.ndarray)):
raise ValueError('Derivative values must be a tensor when there '
'is only one variable in the equation.')
devs = {self.variables[0]: devs}
else:
if not isinstance(devs, (tuple, list)):
raise ValueError('Derivative values must be a list/tuple of tensors '
'when there are multiple variables in the equation.')
devs = {var: devs[i] for i, var in enumerate(self.variables)}

# function states
for key in self.variables:
self.f_states[key][self.idx[0]] = devs[key]

# integral results
integrals = []
idx = ((self.num_step - 1 - self.idx) + bm.arange(self.num_step)) % self.num_step
for i, key in enumerate(self.variables):
integral = self.inits[key] + self.coef[idx, i] @ self.f_states[key]
integrals.append(integral * (dt ** self.alpha[i] / self.alpha[i]))
self.idx.value = (self.idx + 1) % self.num_step

# return integrals
if len(self.variables) == 1:
return integrals[0]
else:
return integrals


register_fde_integrator(name='CaputoEuler', integrator=CaputoEuler)


class CaputoABM(FDEIntegrator):
"""Adams-Bashforth-Moulton (ABM) Method for Caputo fractional differential equations.


"""
pass


class CaputoL1Schema(FDEIntegrator):
r"""The L1 scheme method for the numerical approximation of the Caputo
fractional-order derivative equations [3]_.

For the fractional order :math:`0<\alpha<1`, let the fractional derivative of variable
:math:`x(t)` be

.. math::

\frac{d^{\alpha} x}{d t^{\alpha}}=F(x, t)

The Caputo definition of the fractional derivative for variable :math:`x` is

.. math::

\frac{d^{\alpha} x}{d t^{\alpha}}=\frac{1}{\Gamma(1-\alpha)} \int_{0}^{t} \frac{x^{\prime}(u)}{(t-u)^{\alpha}} d u

where :math:`\Gamma` is the Gamma function.

The fractional-order derivative is capable of integrating the activity of the
function over all past activities weighted by a function that follows a power-law.
Using one of the numerical methods, the L1 scheme method [3]_, the numerical
approximation of the fractional-order derivative of :math:`x` is

.. math::

\frac{d^{\alpha} \chi}{d t^{\alpha}} \approx \frac{(d t)^{-\alpha}}{\Gamma(2-\alpha)}\left[\sum_{k=0}^{N-1}\left[x\left(t_{k+1}\right)-
\mathrm{x}\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-1-k)^{1-\alpha}\right]\right]

Therefore, the numerical solution of original system is given by

.. math::

x\left(t_{N}\right) \approx d t^{\alpha} \Gamma(2-\alpha) F(x, t)+x\left(t_{N-1}\right)-
\left[\sum_{k=0}^{N-2}\left[x\left(t_{k+1}\right)-x\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-1-k)^{1-\alpha}\right]\right]

Hence, the solution of the fractional-order derivative can be described as the
difference between the *Markov term* and the *memory trace*. The *Markov term*
weighted by the gamma function is

.. math::

\text { Markov term }=d t^{\alpha} \Gamma(2-\alpha) F(x, t)+x\left(t_{N-1}\right)

The memory trace (:math:`x`-memory trace since it is related to variable :math:`x`) is

.. math::

\text { Memory trace }=\sum_{k=0}^{N-2}\left[x\left(t_{k+1}\right)-x\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-(k+1))^{1-\alpha}\right]

The memory trace integrates all the past activity and captures the long-term
history of the system. For :math:`\alpha=1`, the memory trace is 0 for any
time :math:`t`. When the fractional order :math:`\alpha` is decreased from 1,
the memory trace non-linearly increases from 0, and its dynamics strongly
depends on time. Thus, the fractional order dynamics strongly deviates
from the first order dynamics.


Examples
--------

>>> import brainpy as bp
>>>
>>> a, b, c = 10, 28, 8 / 3
>>> def lorenz(x, y, z, t):
>>> dx = a * (y - x)
>>> dy = x * (b - z) - y
>>> dz = x * y - c * z
>>> return dx, dy, dz
>>>
>>> duration = 30.
>>> dt = 0.005
>>> inits = [1., 0., 1.]
>>> f = bp.fde.CaputoL1Schema(lorenz, alpha=0.99, num_step=int(duration / dt), inits=inits)
>>> runner = bp.integrators.IntegratorRunner(f, monitors=list('xz'), dt=dt, inits=inits)
>>> runner.run(duration)
>>>
>>> import matplotlib.pyplot as plt
>>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten())
>>> plt.show()


Parameters
----------
f : callable
The derivative function.
alpha: int, float, jnp.ndarray, bm.ndarray, sequence
The fractional-order of the derivative function. Should be in the range of ``(0., 1.]``.
num_step: int
The total time step of the simulation.
inits: sequence
A sequence of the initial values for variables.
dt: float, int
The numerical precision.
name: str
The integrator name.

References
----------
.. [3] Oldham, K., & Spanier, J. (1974). The fractional calculus theory
and applications of differentiation and integration to arbitrary
order. Elsevier.
"""

def __init__(self, f, alpha, num_step, inits, dt=None, name=None):
super(CaputoL1Schema, self).__init__(f=f, alpha=alpha, dt=dt, name=name)

# fractional order
if not jnp.all(jnp.logical_and(self.alpha <= 1, self.alpha > 0)):
raise UnsupportedError(f'Only support the fractional order in (0, 1), '
f'but we got {self.alpha}.')
from scipy.special import gamma
self.gamma_alpha = bm.asarray(gamma(bm.as_numpy(2 - self.alpha)))

# memory length
check_integer(num_step, 'num_step', min_bound=1, allow_none=False)
self.num_step = num_step

# initial values
inits = check_inits(inits, self.variables)
self.inits = {v: bm.Variable(inits[v]) for v in self.variables}
self.register_implicit_vars(self.inits)

# coefficients
ranges = bm.asarray([bm.arange(1, num_step + 2) for _ in self.variables]).T
coef = bm.diff(bm.power(ranges, 1 - self.alpha), axis=0)
self.coef = bm.flip(coef, axis=0)

# variable states
self.diff_states = {v + "_diff": bm.Variable(bm.zeros((num_step,) + self.inits[v].shape))
for v in self.variables}
self.register_implicit_vars(self.diff_states)
self.idx = bm.Variable(bm.asarray([self.num_step - 1], dtype=bm.int32))

# integral function
self.set_integral(self._integral_func)

def hists(self, var=None, numpy=True):
if var is None:
hists_ = {k: bm.vstack([self.inits[k], self.diff_states[k + '_diff']])
for k in self.variables}
hists_ = {k: bm.cumsum(v, axis=0) for k, v in hists_.items()}
if numpy:
hists_ = {k: v.numpy() for k, v in hists_}
return hists_
else:
assert var in self.variables, (f'"{var}" is not defined in equation '
f'variables: {self.variables}')
hists_ = bm.vstack([self.inits[var], self.diff_states[var + '_diff']])
hists_ = bm.cumsum(hists_, axis=0)
if numpy:
hists_ = hists_.numpy()
return hists_

def _check_step(self, args, transform):
dt, t = args
if self.num_step * dt < t:
raise ValueError(f'The maximum number of step is {self.num_step}, '
f'however, the current time {t} require a time '
f'step number {t / dt}.')

def _integral_func(self, *args, **kwargs):
# format arguments
all_args = format_args(args, kwargs, self.arguments)
dt = all_args.pop(DT, self.dt)
if check.is_checking():
id_tap(self._check_step, (dt, all_args['t']))

# derivative values
devs = self.f(**all_args)
if len(self.variables) == 1:
if not isinstance(devs, (bm.ndarray, jnp.ndarray)):
raise ValueError('Derivative values must be a tensor when there '
'is only one variable in the equation.')
devs = {self.variables[0]: devs}
else:
if not isinstance(devs, (tuple, list)):
raise ValueError('Derivative values must be a list/tuple of tensors '
'when there are multiple variables in the equation.')
devs = {var: devs[i] for i, var in enumerate(self.variables)}

# integral results
integrals = []
idx = ((self.num_step - 1 - self.idx) + bm.arange(self.num_step)) % self.num_step
for i, key in enumerate(self.variables):
self.diff_states[key + '_diff'][self.idx[0]] = all_args[key] - self.inits[key]
self.inits[key].value = all_args[key]
markov_term = dt ** self.alpha[i] * self.gamma_alpha[i] * devs[key] + all_args[key]
memory_trace = self.coef[idx, i] @ self.diff_states[key + '_diff']
integral = markov_term - memory_trace
integrals.append(integral)
self.idx.value = (self.idx + 1) % self.num_step

# return integrals
if len(self.variables) == 1:
return integrals[0]
else:
return integrals


register_fde_integrator(name='CaputoL1', integrator=CaputoL1Schema)
register_fde_integrator(name='CaputoL1Schema', integrator=CaputoL1Schema)

+ 190
- 0
brainpy/integrators/fde/GL.py View File

@@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-

"""
This module provides numerical solvers for Grünwald–Letnikov derivative FDEs.
"""

import jax.numpy as jnp

import brainpy.math as bm
from brainpy.errors import UnsupportedError
from brainpy.integrators.constants import DT
from brainpy.tools.checking import check_integer
from .base import FDEIntegrator
from brainpy.integrators.utils import check_inits, format_args

__all__ = [
'GLShortMemory'
]


class GLShortMemory(FDEIntegrator):
r"""Efficient Computation of the Short-Memory Principle in Grünwald-Letnikov Method [1]_.

According to the explicit numerical approximation of Grünwald-Letnikov, the
fractional-order derivative :math:`q` for a discrete function :math:`f(t_K)`
can be described as follows:

.. math::

{{}_{k-\frac{L_{m}}{h}}D_{t_{k}}^{q}}f(t_{k})\approx h^{-q}
\sum\limits_{j=0}^{k}C_{j}^{q}f(t_{k-j})

where :math:`L_{m}` is the memory lenght, :math:`h` is the integration step size,
and :math:`C_{j}^{q}` are the binomial coefficients which are calculated recursively with

.. math::

C_{0}^{q}=1,\ C_{j}^{q}=\left(1- \frac{1+q}{j}\right)C_{j-1}^{q},\ j=1,2, \ldots k.

Then, the numerical solution for a fractional-order differential equation (FODE) expressed
in the form

.. math::

D_{t_{k}}^{q}x(t_{k})=f(x(t_{k}))

can be obtained by

.. math::

x(t_{k})=f(x(t_{k-1}))h^{q}- \sum\limits_{j=1}^{k}C_{j}^{q}x(t_{k-j}).

for :math:`0 < q < 1`. The above expression requires infinity memory length
for numerical solution since the summation term depends on the discritized
time :math:`t_k`. This implies relatively high simulation times.

To reduce the computational time, the upper bound of summation needs to be modified by
:math:`k=v`, where

.. math::

v=\begin{cases} k, & k\leq M,\\ L_{m}, & k > M. \end{cases}

This is known as the short-memory principle, where :math:`M`
is the memory window with a width defined by :math:`M=\frac{L_{m}}{h}`.
As was reported in [2]_, the accuracy increases by increaing the width of memory window.

Examples
--------

>>> import brainpy as bp
>>>
>>> a, b, c = 10, 28, 8 / 3
>>> def lorenz(x, y, z, t):
>>> dx = a * (y - x)
>>> dy = x * (b - z) - y
>>> dz = x * y - c * z
>>> return dx, dy, dz
>>>
>>> integral = bp.fde.GLShortMemory(lorenz,
>>> alpha=0.96,
>>> num_memory=500,
>>> inits=[1., 0., 1.])
>>> runner = bp.integrators.IntegratorRunner(integral,
>>> monitors=list('xyz'),
>>> inits=[1., 0., 1.],
>>> dt=0.005)
>>> runner.run(100.)
>>>
>>> import matplotlib.pyplot as plt
>>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten())
>>> plt.show()


Parameters
----------
f : callable
The derivative function.
alpha: int, float, jnp.ndarray, bm.ndarray, sequence
The fractional-order of the derivative function. Should be in the range of ``(0., 1.)``.
num_memory: int
The length of the short memory.
inits: sequence
A sequence of the initial values for variables.
dt: float, int
The numerical precision.
name: str
The integrator name.

References
----------
.. [1] Clemente-López, D., et al. "Efficient computation of the
Grünwald-Letnikov method for arm-based implementations of
fractional-order chaotic systems." 2019 8th International
Conference on Modern Circuits and Systems Technologies (MOCAST). IEEE, 2019.
.. [2] M. F. Tolba, A. M. AbdelAty, N. S. Soliman, L. A. Said, A. H.
Madian, A. T. Azar, et al., "FPGA implementation of two fractional
order chaotic systems", International Journal of Electronics and
Communications, vol. 78, pp. 162-172, 2017.
"""

def __init__(self, f, alpha, num_memory, inits, dt=None, name=None):
super(GLShortMemory, self).__init__(f=f, alpha=alpha, dt=dt, name=name)

# fractional order
if not jnp.all(jnp.logical_and(self.alpha <= 1, self.alpha > 0)):
raise UnsupportedError(f'Only support the fractional order in (0, 1), '
f'but we got {self.alpha}.')

# memory length
check_integer(num_memory, 'num_memory', min_bound=1, allow_none=False)
self.num_memory = num_memory

# initial values
self.inits = check_inits(inits, self.variables)

# delays
self.delays = {}
for key, val in self.inits.items():
delay = bm.Variable(bm.zeros((self.num_memory,) + val.shape, dtype=bm.float_))
delay[0] = val
self.delays[key] = delay
self._idx = bm.Variable(bm.asarray([1], dtype=bm.int32))
self.register_implicit_vars(self.delays)

# binomial coefficients
bc = (1 - (1 + self.alpha.reshape((-1, 1))) / jnp.arange(1, num_memory + 1))
bc = jnp.cumprod(jnp.vstack([jnp.ones_like(self.alpha), bc.T]), axis=0)
self._binomial_coef = jnp.flip(bc[1:], axis=0)

# integral function
self.set_integral(self._integral_func)

@property
def binomial_coef(self):
return bm.as_numpy(jnp.flip(self._binomial_coef, axis=0))

def _integral_func(self, *args, **kwargs):
# format arguments
all_args = format_args(args, kwargs, self.arguments)
dt = all_args.pop(DT, self.dt)

# derivative values
devs = self.f(**all_args)
if len(self.variables) == 1:
if not isinstance(devs, (bm.ndarray, jnp.ndarray)):
raise ValueError('Derivative values must be a tensor when there '
'is only one variable in the equation.')
devs = {self.variables[0]: devs}
else:
if not isinstance(devs, (tuple, list)):
raise ValueError('Derivative values must be a list/tuple of tensors '
'when there are multiple variables in the equation.')
devs = {var: devs[i] for i, var in enumerate(self.variables)}

# integral results
integrals = []
idx = (self._idx + bm.arange(self.num_memory)) % self.num_memory
for i, var in enumerate(self.variables):
summation = self._binomial_coef[:, i] @ self.delays[var][idx]
integral = (dt ** self.alpha[i]) * devs[var] - summation
self.delays[var][self._idx[0]] = integral
integrals.append(integral)
self._idx.value = (self._idx + 1) % self.num_memory

# return integrals
if len(self.variables) == 1:
return integrals[0]
else:
return integrals

+ 0
- 95
brainpy/integrators/fde/RL.py View File

@@ -1,95 +0,0 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp
from jax import vmap
from jax.lax import cond

from brainpy.math.special import Gamma
from brainpy.tools.checking import check_float

__all__ = [
'RL',
]


def RLcoeffs(index_k, index_j, alpha):
"""Calculates coefficients for the RL differintegral operator.

see Baleanu, D., Diethelm, K., Scalas, E., and Trujillo, J.J. (2012). Fractional
Calculus: Models and Numerical Methods. World Scientific.
"""

def f1(x):
k, j = x
return ((k - 1) ** (1 - alpha) -
(k + alpha - 1) * k ** -alpha)

def f2(x):
k, j = x
return cond(k == j, lambda _: 1., f3, x)

def f3(x):
k, j = x
return ((k - j + 1) ** (1 - alpha) +
(k - j - 1) ** (1 - alpha) -
2 * (k - j) ** (1 - alpha))

return cond(index_j == 0, f1, f2, (index_k, index_j))


def RLmatrix(alpha, N):
""" Define the coefficient matrix for the RL algorithm. """
ij = jnp.tril_indices(N, -1)
coeff = vmap(RLcoeffs, in_axes=(0, 0, None))(ij[0], ij[1], alpha)
mat = jnp.zeros((N, N)).at[ij].set(coeff)
diagonal = jnp.arange(N)
mat = mat.at[diagonal, diagonal].set(1.)
return mat / Gamma(2 - alpha)


def RL(alpha, f, domain_start=0.0, domain_end=1.0, dt=0.01):
""" Calculate the RL algorithm using a trapezoid rule over
an array of function values.

Examples
--------

>>> RL_sqrt = RL(0.5, lambda x: x ** 0.5)
>>> RL_poly = RL(0.5, lambda x: x**2 - 1, 0., 1., 100)

Parameters
----------
alpha : float
The order of the differintegral to be computed.
f : function
This is the function that is to be differintegrated.
domain_start : float, int
The left-endpoint of the function domain. Default value is 0.
domain_end : float, int
The right-endpoint of the function domain; the point at which the
differintegral is being evaluated. Default value is 1.
dt : float, int
The number of points in the domain. Default value is 100.

Returns
-------
RL : float 1d-array
Each element of the array is the RL differintegral evaluated at the
corresponding function array index.
"""
# checking
assert domain_start < domain_end, ('"domain_start" should be lower than "domain_end", '
f'while we got {domain_start} >= {domain_end}')
check_float(alpha, 'alpha', allow_none=False)
check_float(domain_start, 'domain_start', allow_none=False)
check_float(domain_end, 'domain_start', allow_none=False)
check_float(dt, 'dt', allow_none=False)
# computing
points = jnp.arange(domain_start, domain_end, dt)
f_values = vmap(f)(points)
# Calculate the RL differintegral.
D = RLmatrix(alpha, points.shape[0])
RL = dt ** -alpha * jnp.dot(D, f_values)
return RL



+ 7
- 0
brainpy/integrators/fde/__init__.py View File

@@ -1 +1,8 @@
# -*- coding: utf-8 -*-

from .base import *
from .generic import *
from .GL import *
from .Caputo import *



+ 76
- 2
brainpy/integrators/fde/base.py View File

@@ -1,8 +1,82 @@
# -*- coding: utf-8 -*-

from ..base import Integrator
import abc
from typing import Union, Callable

import jax.numpy as jnp
import brainpy.math as bm
from brainpy.integrators.base import Integrator
from brainpy.integrators.utils import get_args
from brainpy.errors import UnsupportedError


__all__ = [
'FDEIntegrator'
]


class FDEIntegrator(Integrator):
pass
"""Numerical integrator for fractional differential equations (FEDs).

Parameters
----------
f : callable
The derivative function.
alpha: int, float, jnp.ndarray, bm.ndarray, sequence
The fractional-order of the derivative function.
dt: float, int
The numerical precision.
name: str
The integrator name.
"""

"""The fraction order for each variable."""
alpha: jnp.ndarray

"""The numerical integration precision."""
dt: Union[float, int]

"""The fraction derivative function."""
f: Callable

def __init__(self, f, alpha, dt=None, name=None):
dt = bm.get_dt() if dt is None else dt
parses = get_args(f)
variables = parses[0] # variable names, (before 't')
parameters = parses[1] # parameter names, (after 't')
arguments = parses[2] # function arguments

# super initialization
super(FDEIntegrator, self).__init__(name=name,
variables=variables,
parameters=parameters,
arguments=arguments,
dt=dt)

# derivative function
self.f = f

# fractional-order
if isinstance(alpha, (int, float)):
alpha = jnp.ones(len(self.variables)) * alpha
elif isinstance(alpha, (jnp.ndarray, bm.ndarray)):
alpha = bm.as_device_array(alpha)
elif isinstance(alpha, (list, tuple)):
for a in alpha:
assert isinstance(a, (float, int)), (f'Must be a tuple/list of int/float, '
f'but we got {type(a)}: {a}')
alpha = jnp.asarray(alpha)
else:
raise UnsupportedError(f'Do not support {type(alpha)}, please '
f'set fractional-order as number/tuple/list/tensor.')
if len(alpha) != len(self.variables):
raise ValueError(f'There are {len(self.variables)} variables, '
f'while we only got {len(alpha)} fractional-order '
f'settings: {alpha}')
self.alpha = alpha

@abc.abstractmethod
def build(self):
raise NotImplementedError('Must implement how to build your step function.')



+ 92
- 0
brainpy/integrators/fde/generic.py View File

@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-

from .base import FDEIntegrator

__all__ = [
'fdeint',
'set_default_fdeint',
'get_default_fdeint',
'register_fde_integrator',
'get_supported_methods',
]

name2method = {
}

_DEFAULT_DDE_METHOD = 'CaputoL1'


def fdeint(f=None, method='CaputoL1', **kwargs):
"""Numerical integration for FDEs.

Parameters
----------
f : callable, function
The derivative function.
method : str
The shortcut name of the numerical integrator.

Returns
-------
integral : FDEIntegrator
The numerical solver of `f`.
"""
method = _DEFAULT_DDE_METHOD if method is None else method
if method not in name2method:
raise ValueError(f'Unknown FDE numerical method "{method}". Currently '
f'BrainPy only support: {list(name2method.keys())}')

if f is None:
return lambda f: name2method[method](f, **kwargs)
else:
return name2method[method](f, **kwargs)


def set_default_fdeint(method):
"""Set the default ODE numerical integrator method for differential equations.

Parameters
----------
method : str, callable
Numerical integrator method.
"""
if not isinstance(method, str):
raise ValueError(f'Only support string, not {type(method)}.')
if method not in name2method:
raise ValueError(f'Unsupported ODE_INT numerical method: {method}.')

global _DEFAULT_DDE_METHOD
_DEFAULT_ODE_METHOD = method


def get_default_fdeint():
"""Get the default ODE numerical integrator method.

Returns
-------
method : str
The default numerical integrator method.
"""
return _DEFAULT_DDE_METHOD


def register_fde_integrator(name, integrator):
"""Register a new ODE integrator.

Parameters
----------
name: ste
The integrator name.
integrator: type
The integrator.
"""
if name in name2method:
raise ValueError(f'"{name}" has been registered in ODE integrators.')
if not issubclass(integrator, FDEIntegrator):
raise ValueError(f'"integrator" must be an instance of {FDEIntegrator.__name__}')
name2method[name] = integrator


def get_supported_methods():
"""Get all supported numerical methods for DDEs."""
return list(name2method.keys())

+ 33
- 0
brainpy/integrators/fde/tests/test_Caputo.py View File

@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-


import unittest

import numpy as np

import brainpy as bp


class TestCaputoL1(unittest.TestCase):
def test1(self):
bp.math.enable_x64()
alpha = 0.9
intg = bp.fde.CaputoL1Schema(lambda a, t: a,
alpha=alpha,
num_step=10,
inits=[1., ])
for N in [2, 3, 4, 5, 6, 7, 8]:
diff = np.random.rand(N - 1, 1)
memory_trace = 0
for i in range(N - 1):
c = (N - i) ** (1 - alpha) - (N - i - 1) ** (1 - alpha)
memory_trace += c * diff[i]

intg.idx[0] = N - 1
intg.diff_states['a_diff'][:N - 1] = bp.math.asarray(diff)
idx = ((intg.num_step - intg.idx) + np.arange(intg.num_step)) % intg.num_step
memory_trace2 = intg.coef[idx, 0] @ intg.diff_states['a_diff']

print()
print(memory_trace[0], )
print(memory_trace2[0], bp.math.array_equal(memory_trace[0], memory_trace2[0]))

+ 32
- 0
brainpy/integrators/fde/tests/test_GL.py View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-


import unittest
import matplotlib.pyplot as plt
import brainpy as bp


class TestGLShortMemory(unittest.TestCase):
def test_lorenz(self):
a, b, c = 10, 28, 8 / 3

def lorenz(x, y, z, t):
dx = a * (y - x)
dy = x * (b - z) - y
dz = x * y - c * z
return dx, dy, dz

integral = bp.fde.GLShortMemory(lorenz,
alpha=0.99,
num_memory=500,
inits=[1., 0., 1.])
runner = bp.integrators.IntegratorRunner(integral,
monitors=list('xyz'),
inits=[1., 0., 1.],
dt=0.005)
runner.run(100.)

plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten())
plt.show(block=False)



+ 0
- 16
brainpy/integrators/fde/tests/test_RL.py View File

@@ -1,16 +0,0 @@
# -*- coding: utf-8 -*-

import unittest
from brainpy.integrators.fde.RL import RLmatrix
import brainpy.math as bm


class TestRLAlgorithm(unittest.TestCase):
def test_RL_matrix_shape(self):
bm.enable_x64()
print()
print(RLmatrix(0.4, 5))
self.assertTrue(RLmatrix(0.4, 10).shape == (10, 10))
self.assertTrue(RLmatrix(1.2, 5).shape == (5, 5))



+ 1
- 1
brainpy/integrators/joint_eq.py View File

@@ -153,7 +153,7 @@ class JointEq(object):
for par in args[len(vars) + 1:]:
if (par not in vars_in_eqs) and (par not in all_arg_pars) and (par not in all_kwarg_pars):
all_arg_pars.append(par)
for key, value in kwargs.values():
for key, value in kwargs.items():
if key in all_kwarg_pars and value != all_kwarg_pars[key]:
raise errors.DiffEqError(f'We got two different default value of "{key}": '
f'{all_kwarg_pars[key]} != {value}')


+ 39
- 0
brainpy/integrators/ode/adaptive_rk.py View File

@@ -58,6 +58,7 @@ from brainpy import errors
from brainpy.integrators import constants as C, utils
from brainpy.integrators.ode import common
from brainpy.integrators.ode.base import ODEIntegrator
from .generic import register_ode_integrator

__all__ = [
'AdaptiveRKIntegrator',
@@ -239,6 +240,9 @@ class RKF12(AdaptiveRKIntegrator):
C = [0, 0.5, 1]


register_ode_integrator('rkf12', RKF12)


class RKF45(AdaptiveRKIntegrator):
r"""The Runge–Kutta–Fehlberg method for ODEs.

@@ -285,6 +289,9 @@ class RKF45(AdaptiveRKIntegrator):
C = [0, 0.25, 0.375, '12/13', 1, '1/3']


register_ode_integrator('rkf45', RKF45)


class DormandPrince(AdaptiveRKIntegrator):
r"""The Dormand–Prince method for ODEs.

@@ -336,6 +343,9 @@ class DormandPrince(AdaptiveRKIntegrator):
C = [0, 0.2, 0.3, 0.8, '8/9', 1, 1]


register_ode_integrator('rkdp', DormandPrince)


class CashKarp(AdaptiveRKIntegrator):
r"""The Cash–Karp method for ODEs.

@@ -384,6 +394,9 @@ class CashKarp(AdaptiveRKIntegrator):
C = [0, 0.2, 0.3, 0.6, 1, 0.875]


register_ode_integrator('ck', CashKarp)


class BogackiShampine(AdaptiveRKIntegrator):
r"""The Bogacki–Shampine method for ODEs.

@@ -427,6 +440,9 @@ class BogackiShampine(AdaptiveRKIntegrator):
C = [0, 0.5, 0.75, 1]


register_ode_integrator('bs', BogackiShampine)


class HeunEuler(AdaptiveRKIntegrator):
r"""The Heun–Euler method for ODEs.

@@ -457,6 +473,9 @@ class HeunEuler(AdaptiveRKIntegrator):
C = [0, 1]


register_ode_integrator('heun_euler', HeunEuler)


class DOP853(AdaptiveRKIntegrator):
# def DOP853(f=None, tol=None, adaptive=None, dt=None, show_code=None, each_var_is_scalar=None):
r"""The DOP853 method for ODEs.
@@ -473,3 +492,23 @@ class DOP853(AdaptiveRKIntegrator):
.. [2] http://www.unige.ch/~hairer/software.html
"""
pass


class BoSh3(AdaptiveRKIntegrator):
"""
Bogacki--Shampine's 3/2 method.

3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for
adaptive step sizing.

"""
A = [(),
(0.5,),
(0.0, 0.75),
('2/9', '1/3', '4/9')]
B1 = ['2/9', '1/3', '4/9', 0.0]
B2 = ['-5/72', 1 / 12, '1/9', '-1/8']
C = [0., 0.5, 0.75, 1.0]


register_ode_integrator('BoSh3', BoSh3)

+ 12
- 1
brainpy/integrators/ode/base.py View File

@@ -21,6 +21,17 @@ def f_names(f):

class ODEIntegrator(Integrator):
"""Numerical Integrator for Ordinary Differential Equations (ODEs).

Parameters
----------
f : callable
The derivative function.
var_type: str
The type for each variable.
dt: float, int
The numerical precision.
name: str
The integrator name.
"""

def __init__(self, f, var_type=None, dt=None, name=None, show_code=False):
@@ -29,7 +40,7 @@ class ODEIntegrator(Integrator):
parses = utils.get_args(f)
variables = parses[0] # variable names, (before 't')
parameters = parses[1] # parameter names, (after 't')
arguments = parses[2] # function arguments
arguments = parses[2] # function arguments

# super initialization
super(ODEIntegrator, self).__init__(name=name,


+ 37
- 0
brainpy/integrators/ode/explicit_rk.py View File

@@ -70,6 +70,7 @@ More details please see references [2]_ [3]_ [4]_.
from brainpy.integrators import constants as C, utils
from brainpy.integrators.ode import common
from brainpy.integrators.ode.base import ODEIntegrator
from .generic import register_ode_integrator

__all__ = [
'ExplicitRKIntegrator',
@@ -247,6 +248,9 @@ class Euler(ExplicitRKIntegrator):
C = [0]


register_ode_integrator('euler', Euler)


class MidPoint(ExplicitRKIntegrator):
r"""Explicit midpoint method for ODEs.

@@ -341,6 +345,9 @@ class MidPoint(ExplicitRKIntegrator):
C = [0, 0.5]


register_ode_integrator('midpoint', MidPoint)


class Heun2(ExplicitRKIntegrator):
r"""Heun's method for ODEs.

@@ -406,6 +413,9 @@ class Heun2(ExplicitRKIntegrator):
C = [0, 1]


register_ode_integrator('heun2', Heun2)


class Ralston2(ExplicitRKIntegrator):
r"""Ralston's method for ODEs.

@@ -437,6 +447,9 @@ class Ralston2(ExplicitRKIntegrator):
C = [0, '2/3']


register_ode_integrator('ralston2', Ralston2)


class RK2(ExplicitRKIntegrator):
r"""Generic second order Runge-Kutta method for ODEs.

@@ -560,6 +573,9 @@ class RK2(ExplicitRKIntegrator):
super(RK2, self).__init__(f=f, var_type=var_type, dt=dt, name=name, show_code=show_code)


register_ode_integrator('rk2', RK2)


class RK3(ExplicitRKIntegrator):
r"""Classical third-order Runge-Kutta method for ODEs.

@@ -598,6 +614,9 @@ class RK3(ExplicitRKIntegrator):
C = [0, 0.5, 1]


register_ode_integrator('rk3', RK3)


class Heun3(ExplicitRKIntegrator):
r"""Heun's third-order method for ODEs.

@@ -622,6 +641,9 @@ class Heun3(ExplicitRKIntegrator):
C = [0, '1/3', '2/3']


register_ode_integrator('heun3', Heun3)


class Ralston3(ExplicitRKIntegrator):
r"""Ralston's third-order method for ODEs.

@@ -651,6 +673,9 @@ class Ralston3(ExplicitRKIntegrator):
C = [0, 0.5, 0.75]


register_ode_integrator('ralston3', Ralston3)


class SSPRK3(ExplicitRKIntegrator):
r"""Third-order Strong Stability Preserving Runge-Kutta (SSPRK3).

@@ -674,6 +699,9 @@ class SSPRK3(ExplicitRKIntegrator):
C = [0, 1, 0.5]


register_ode_integrator('ssprk3', SSPRK3)


class RK4(ExplicitRKIntegrator):
r"""Classical fourth-order Runge-Kutta method for ODEs.

@@ -741,6 +769,9 @@ class RK4(ExplicitRKIntegrator):
C = [0, 0.5, 0.5, 1]


register_ode_integrator('rk4', RK4)


class Ralston4(ExplicitRKIntegrator):
r"""Ralston's fourth-order method for ODEs.

@@ -772,6 +803,9 @@ class Ralston4(ExplicitRKIntegrator):
C = [0, .4, .45573725, 1]


register_ode_integrator('ralston4', Ralston4)


class RK4Rule38(ExplicitRKIntegrator):
r"""3/8-rule fourth-order method for ODEs.

@@ -811,3 +845,6 @@ class RK4Rule38(ExplicitRKIntegrator):
A = [(), ('1/3',), ('-1/3', '1'), (1, -1, 1)]
B = [0.125, 0.375, 0.375, 0.125]
C = [0, '1/3', '2/3', 1]


register_ode_integrator('rk4_38rule', RK4Rule38)

+ 9
- 0
brainpy/integrators/ode/exponential.py View File

@@ -113,6 +113,7 @@ from brainpy.base.collector import Collector
from brainpy.integrators import constants as C, utils, joint_eq
from brainpy.integrators.analysis_by_ast import separate_variables
from brainpy.integrators.ode.base import ODEIntegrator
from .generic import register_ode_integrator

try:
import sympy
@@ -506,6 +507,10 @@ class ExponentialEuler(ODEIntegrator):
return s_df_part


register_ode_integrator('exponential_euler', ExponentialEuler)
register_ode_integrator('exp_euler', ExponentialEuler)


class ExpEulerAuto(ODEIntegrator):
"""Exponential Euler method using automatic differentiation.

@@ -762,3 +767,7 @@ class ExpEulerAuto(ODEIntegrator):
return args[0] + dt * phi * derivative

return [(integral, vars, pars), ]


register_ode_integrator('exp_euler_auto', ExpEulerAuto)
register_ode_integrator('exp_auto', ExpEulerAuto)

+ 1
- 29
brainpy/integrators/ode/generic.py View File

@@ -1,9 +1,6 @@
# -*- coding: utf-8 -*-

from .base import ODEIntegrator
from .adaptive_rk import *
from .explicit_rk import *
from .exponential import *

__all__ = [
'odeint',
@@ -14,31 +11,6 @@ __all__ = [
]

name2method = {
# explicit RK
'euler': Euler, 'Euler': Euler,
'midpoint': MidPoint, 'MidPoint': MidPoint,
'heun2': Heun2, 'Heun2': Heun2,
'ralston2': Ralston2, 'Ralston2': Ralston2,
'rk2': RK2, 'RK2': RK2,
'rk3': RK3, 'RK3': RK3,
'heun3': Heun3, 'Heun3': Heun3,
'ralston3': Ralston3, 'Ralston3': Ralston3,
'ssprk3': SSPRK3, 'SSPRK3': SSPRK3,
'rk4': RK4, 'RK4': RK4,
'ralston4': Ralston4, 'Ralston4': Ralston4,
'rk4_38rule': RK4Rule38, 'RK4Rule38': RK4Rule38,

# adaptive RK
'rkf12': RKF12, 'RKF12': RKF12,
'rkf45': RKF45, 'RKF45': RKF45,
'rkdp': DormandPrince, 'dp': DormandPrince, 'DormandPrince': DormandPrince,
'ck': CashKarp, 'CashKarp': CashKarp,
'bs': BogackiShampine, 'BogackiShampine': BogackiShampine,
'heun_euler': HeunEuler, 'HeunEuler': HeunEuler,

# exponential integrators
'exponential_euler': ExponentialEuler, 'exp_euler': ExponentialEuler, 'ExponentialEuler': ExponentialEuler,
'exp_euler_auto': ExpEulerAuto, 'exp_auto': ExpEulerAuto, 'ExpEulerAuto': ExpEulerAuto,
}

_DEFAULT_DDE_METHOD = 'euler'
@@ -134,7 +106,7 @@ def register_ode_integrator(name, integrator):
"""
if name in name2method:
raise ValueError(f'"{name}" has been registered in ODE integrators.')
if ODEIntegrator not in integrator.__bases__:
if not issubclass(integrator, ODEIntegrator):
raise ValueError(f'"integrator" must be an instance of {ODEIntegrator.__name__}')
name2method[name] = integrator



+ 1
- 1
brainpy/integrators/runner.py View File

@@ -93,7 +93,7 @@ class IntegratorRunner(Runner):
>>> dt = 0.01; beta=2.; gamma=1.; tau=2.; n=9.65
>>> mg_eq = lambda x, t, xdelay: (beta * xdelay(t - tau) / (1 + xdelay(t - tau) ** n)
>>> - gamma * x)
>>> xdelay = bm.FixedLenDelay(1, delay_len=tau, dt=dt, before_t0=lambda t: 1.2)
>>> xdelay = bm.TimeDelay(bm.asarray([1.2]), delay_len=tau, dt=dt, before_t0=lambda t: 1.2)
>>> integral = bp.ddeint(mg_eq, method='rk4', state_delays={'x': xdelay})
>>> runner = bp.integrators.IntegratorRunner(
>>> integral,


+ 1
- 12
brainpy/integrators/sde/generic.py View File

@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-

from .base import SDEIntegrator
from .normal import *
from .srk_scalar import *

__all__ = [
'sdeint',
@@ -13,15 +11,6 @@ __all__ = [
]

name2method = {
'euler': Euler, 'Euler': Euler,
'heun': Heun, 'Heun': Heun,
'milstein': Milstein, 'Milstein': Milstein,
'exponential_euler': ExponentialEuler, 'exp_euler': ExponentialEuler, 'ExponentialEuler': ExponentialEuler,

# RK methods
'srk1w1': SRK1W1, 'SRK1W1': SRK1W1,
'srk2w1': SRK2W1, 'SRK2W1': SRK2W1,
'klpl': KlPl, 'KlPl': KlPl,
}

_DEFAULT_SDE_METHOD = 'euler'
@@ -98,7 +87,7 @@ def register_sde_integrator(name, integrator):
"""
if name in name2method:
raise ValueError(f'"{name}" has been registered in SDE integrators.')
if SDEIntegrator not in integrator.__bases__:
if not issubclass(integrator, SDEIntegrator):
raise ValueError(f'"integrator" must be an instance of {SDEIntegrator.__name__}')
name2method[name] = integrator



+ 14
- 0
brainpy/integrators/sde/normal.py View File

@@ -6,6 +6,7 @@ from brainpy import errors, math
from brainpy.integrators import constants, utils
from brainpy.integrators.analysis_by_ast import separate_variables
from brainpy.integrators.sde.base import SDEIntegrator
from .generic import register_sde_integrator

try:
import sympy
@@ -142,6 +143,9 @@ class Euler(SDEIntegrator):
func_name=self.func_name)


register_sde_integrator('euler', Euler)


class Heun(Euler):
def __init__(self, f, g, dt=None, name=None, show_code=False,
var_type=None, intg_type=None, wiener_type=None):
@@ -154,6 +158,9 @@ class Heun(Euler):
self.build()


register_sde_integrator('heun', Heun)


class Milstein(SDEIntegrator):
def __init__(self, f, g, dt=None, name=None, show_code=False,
var_type=None, intg_type=None, wiener_type=None):
@@ -238,6 +245,9 @@ class Milstein(SDEIntegrator):
func_name=self.func_name)


register_sde_integrator('milstein', Milstein)


class ExponentialEuler(SDEIntegrator):
r"""First order, explicit exponential Euler method.

@@ -399,3 +409,7 @@ class ExponentialEuler(SDEIntegrator):
if hasattr(self.derivative[constants.F], '__self__'):
host = self.derivative[constants.F].__self__
self.integral = self.integral.__get__(host, host.__class__)


register_sde_integrator('exponential_euler', ExponentialEuler)
register_sde_integrator('exp_euler', ExponentialEuler)

+ 11
- 1
brainpy/integrators/sde/srk_scalar.py View File

@@ -2,6 +2,7 @@

from brainpy.integrators import constants, utils
from brainpy.integrators.sde.base import SDEIntegrator
from .generic import register_sde_integrator

__all__ = [
'SRK1W1',
@@ -175,6 +176,9 @@ class SRK1W1(SDEIntegrator):
func_name=self.func_name)


register_sde_integrator('srk1w1', SRK1W1)


class SRK2W1(SDEIntegrator):
r"""Order 1.5 Strong SRK Methods for SDEs with Scalar Noise.

@@ -315,6 +319,9 @@ class SRK2W1(SDEIntegrator):
func_name=self.func_name)


register_sde_integrator('srk2w1', SRK2W1)


class KlPl(SDEIntegrator):
def __init__(self, f, g, dt=None, name=None, show_code=False,
var_type=None, intg_type=None, wiener_type=None):
@@ -354,7 +361,7 @@ class KlPl(SDEIntegrator):
self.code_lines.append(f' {var}_g1 = -{var}_I1 + {var}_I11/dt_sqrt + {var}_I10/{constants.DT}')
self.code_lines.append(f' {var}_g2 = {var}_I11 / dt_sqrt')
self.code_lines.append(f' {var}_new = {var} + {constants.DT} * {var}_f_H0s1 + '
f'{var}_g1 * {var}_g_H1s1 + {var}_g2 * {var}_g_H1s2')
f'{var}_g1 * {var}_g_H1s1 + {var}_g2 * {var}_g_H1s2')
self.code_lines.append(' ')

# returns
@@ -367,3 +374,6 @@ class KlPl(SDEIntegrator):
code_lines=self.code_lines,
show_code=self.show_code,
func_name=self.func_name)


register_sde_integrator('klpl', KlPl)

+ 37
- 0
brainpy/integrators/utils.py View File

@@ -4,12 +4,17 @@
import inspect
from pprint import pprint

import brainpy.math as bm
from brainpy.errors import UnsupportedError

from brainpy import errors

__all__ = [
'get_args',
'check_kws',
'compile_code',
'check_inits',
'format_args',
]


@@ -103,3 +108,35 @@ def compile_code(code_lines, code_scope, func_name, show_code=False):
exec(compile(code, '', 'exec'), code_scope)
new_f = code_scope[func_name]
return new_f


def check_inits(inits, variables):
if isinstance(inits, (tuple, list)):
assert len(inits) == len(variables), (f'Then number of variables is {len(variables)}, '
f'however we only got {len(inits)} initial values.')
inits = {v: inits[i] for i, v in enumerate(variables)}
elif isinstance(inits, dict):
assert len(inits) == len(variables), (f'Then number of variables is {len(variables)}, '
f'however we only got {len(inits)} initial values.')
else:
raise UnsupportedError('Only supports dict/sequence of data for initial values. '
f'But we got {type(inits)}: {inits}')
for key in list(inits.keys()):
if key not in variables:
raise ValueError(f'"{key}" is not defined in variables: {variables}')
val = inits[key]
if isinstance(val, (float, int)):
inits[key] = bm.asarray([val], dtype=bm.float_)
return inits


def format_args(args, kwargs, arguments):
all_args = dict()
for i, arg in enumerate(args):
all_args[arguments[i]] = arg
for key, arg in kwargs.items():
if key in all_args:
raise ValueError(f'{key} has been provided in *args, '
f'but we detect it again in **kwargs.')
all_args[key] = arg
return all_args

+ 10
- 10
brainpy/losses/__init__.py View File

@@ -41,7 +41,7 @@ def _return(outputs, reduction):


def cross_entropy_loss(logits, targets, weight=None, reduction='mean'):
"""This criterion combines ``LogSoftmax`` and `NLLLoss`` in one single class.
r"""This criterion combines ``LogSoftmax`` and `NLLLoss`` in one single class.

It is useful when training a classification problem with `C` classes.
If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
@@ -120,7 +120,7 @@ def cross_entropy_loss(logits, targets, weight=None, reduction='mean'):


def cross_entropy_sparse(logits, labels):
"""Computes the softmax cross-entropy loss.
r"""Computes the softmax cross-entropy loss.

Args:
logits: (batch, ..., #class) tensor of logits.
@@ -155,7 +155,7 @@ def cross_entropy_sigmoid(logits, labels):


def l1_loos(logits, targets, reduction='sum'):
"""Creates a criterion that measures the mean absolute error (MAE) between each element in
r"""Creates a criterion that measures the mean absolute error (MAE) between each element in
the logits :math:`x` and targets :math:`y`. It is useful in regression problems.

The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
@@ -207,7 +207,7 @@ def l1_loos(logits, targets, reduction='sum'):


def l2_loss(predicts, targets):
"""Computes the L2 loss.
r"""Computes the L2 loss.

The 0.5 term is standard in "Pattern Recognition and Machine Learning"
by Bishop [1]_, but not "The Elements of Statistical Learning" by Tibshirani.
@@ -246,7 +246,7 @@ def l2_norm(x):


def mean_absolute_error(x, y, axis=None):
"""Computes the mean absolute error between x and y.
r"""Computes the mean absolute error between x and y.

Args:
x: a tensor of shape (d0, .. dN-1).
@@ -261,7 +261,7 @@ def mean_absolute_error(x, y, axis=None):


def mean_squared_error(predicts, targets, axis=None):
"""Computes the mean squared error between x and y.
r"""Computes the mean squared error between x and y.

Args:
predicts: a tensor of shape (d0, .. dN-1).
@@ -276,7 +276,7 @@ def mean_squared_error(predicts, targets, axis=None):


def mean_squared_log_error(y_true, y_pred, axis=None):
"""Computes the mean squared logarithmic error between y_true and y_pred.
r"""Computes the mean squared logarithmic error between y_true and y_pred.

Args:
y_true: a tensor of shape (d0, .. dN-1).
@@ -291,7 +291,7 @@ def mean_squared_log_error(y_true, y_pred, axis=None):


def huber_loss(predicts, targets, delta: float = 1.0):
"""Huber loss.
r"""Huber loss.

Huber loss is similar to L2 loss close to zero, L1 loss away from zero.
If gradient descent is applied to the `huber loss`, it is equivalent to
@@ -353,7 +353,7 @@ def multiclass_logistic_loss(label: int, logits: jn.ndarray) -> float:


def smooth_labels(labels, alpha: float) -> jn.ndarray:
"""Apply label smoothing.
r"""Apply label smoothing.
Label smoothing is often used in combination with a cross-entropy loss.
Smoothed labels favour small logit gaps, and it has been shown that this can
provide better model calibration by preventing overconfident predictions.
@@ -411,7 +411,7 @@ def softmax_cross_entropy(logits, labels):


def log_cosh(predicts, targets=None, ):
"""Calculates the log-cosh loss for a set of predictions.
r"""Calculates the log-cosh loss for a set of predictions.

log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)`
for large x. It is a twice differentiable alternative to the Huber loss.


+ 2
- 3
brainpy/math/__init__.py View File

@@ -46,7 +46,7 @@ from . import random
from .autograd import *
from .controls import *
from .jit import *
from .parallels import *
# from .parallels import *

# settings
from . import setting
@@ -56,8 +56,7 @@ from .function import *
# functions
from .activations import *
from . import activations
from .compact import *
from . import special
from .compat import *


def get_dint():


+ 2
- 3
brainpy/math/autograd.py View File

@@ -1,8 +1,7 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable, Dict, Sequence

from functools import partial
from typing import Union, Callable, Dict, Sequence

import jax
import numpy as np
@@ -41,7 +40,7 @@ def _make_cls_call_func(grad_func, grad_tree, grad_vars, dyn_vars,
except UnexpectedTracerError as e:
for v, d in zip(grad_vars, old_grad_vs): v.value = d
for v, d in zip(dyn_vars, old_dyn_vs): v.value = d
raise errors.JaxTracerError(variables=dyn_vars+grad_vars) from e
raise errors.JaxTracerError(variables=dyn_vars + grad_vars) from e
for v, d in zip(grad_vars, new_grad_vs): v.value = d
for v, d in zip(dyn_vars, new_dyn_vs): v.value = d



brainpy/math/compact/__init__.py → brainpy/math/compat/__init__.py View File

@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-

__all__ = [
'optimizers', 'losses'
'optimizers', 'losses',
'FixedLenDelay',
]

from . import optimizers, losses
from .delay_vars import *


+ 45
- 0
brainpy/math/compat/delay_vars.py View File

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-

import warnings
from typing import Union, Callable

import jax.numpy as jnp

from brainpy.math.jaxarray import ndarray
from brainpy.math.numpy_ops import zeros
from brainpy.math.delay_vars import TimeDelay


__all__ = [
'FixedLenDelay'
]


def FixedLenDelay(shape,
delay_len: Union[float, int],
before_t0: Union[Callable, ndarray, jnp.ndarray, float, int] = None,
t0: Union[float, int] = 0.,
dt: Union[float, int] = None,
name: str = None,
interp_method='linear_interp', ):
"""Delay variable which has a fixed delay length.

.. deprecated:: 2.1.2
Please use "brainpy.math.TimeDelay" instead.

See Also
--------
TimeDelay

"""
warnings.warn('Please use "brainpy.math.TimeDelay" instead. '
'"brainpy.math.FixedLenDelay" is deprecated since version 2.1.2. ',
DeprecationWarning)
return TimeDelay(inits=zeros(shape),
delay_len=delay_len,
before_t0=before_t0,
t0=t0,
dt=dt,
name=name,
interp_method=interp_method)


brainpy/math/compact/losses.py → brainpy/math/compat/losses.py View File

@@ -17,6 +17,11 @@ __all__ = [


def cross_entropy_loss(*args, **kwargs):
"""Cross entropy loss.

.. deprecated:: 2.1.0
Please use "brainpy.losses.cross_entropy_loss" instead.
"""
warnings.warn('Please use "brainpy.losses.XXX" instead. '
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ',
DeprecationWarning)
@@ -24,6 +29,11 @@ def cross_entropy_loss(*args, **kwargs):


def l1_loos(*args, **kwargs):
"""L1 loss.

.. deprecated:: 2.1.0
Please use "brainpy.losses.l1_loss" instead.
"""
warnings.warn('Please use "brainpy.losses.XXX" instead. '
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ',
DeprecationWarning)
@@ -31,6 +41,11 @@ def l1_loos(*args, **kwargs):


def l2_loss(*args, **kwargs):
"""L2 loss.

.. deprecated:: 2.1.0
Please use "brainpy.losses.l2_loss" instead.
"""
warnings.warn('Please use "brainpy.losses.XXX" instead. '
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ',
DeprecationWarning)
@@ -38,6 +53,11 @@ def l2_loss(*args, **kwargs):


def l2_norm(*args, **kwargs):
"""L2 normal.

.. deprecated:: 2.1.0
Please use "brainpy.losses.l2_norm" instead.
"""
warnings.warn('Please use "brainpy.losses.XXX" instead. '
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ',
DeprecationWarning)
@@ -45,6 +65,11 @@ def l2_norm(*args, **kwargs):


def huber_loss(*args, **kwargs):
"""Huber loss.

.. deprecated:: 2.1.0
Please use "brainpy.losses.huber_loss" instead.
"""
warnings.warn('Please use "brainpy.losses.XXX" instead. '
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ',
DeprecationWarning)
@@ -52,6 +77,11 @@ def huber_loss(*args, **kwargs):


def mean_absolute_error(*args, **kwargs):
"""mean absolute error loss.

.. deprecated:: 2.1.0
Please use "brainpy.losses.mean_absolute_error" instead.
"""
warnings.warn('Please use "brainpy.losses.XXX" instead. '
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ',
DeprecationWarning)
@@ -59,6 +89,11 @@ def mean_absolute_error(*args, **kwargs):


def mean_squared_error(*args, **kwargs):
"""Mean squared error loss.

.. deprecated:: 2.1.0
Please use "brainpy.losses.mean_squared_error" instead.
"""
warnings.warn('Please use "brainpy.losses.XXX" instead. '
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ',
DeprecationWarning)
@@ -66,6 +101,11 @@ def mean_squared_error(*args, **kwargs):


def mean_squared_log_error(*args, **kwargs):
"""Mean squared log error loss.

.. deprecated:: 2.1.0
Please use "brainpy.losses.mean_squared_log_error" instead.
"""
warnings.warn('Please use "brainpy.losses.XXX" instead. '
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ',
DeprecationWarning)

brainpy/math/compact/optimizers.py → brainpy/math/compat/optimizers.py View File

@@ -22,6 +22,11 @@ __all__ = [


def SGD(*args, **kwargs):
"""SGD optimizer.

.. deprecated:: 2.1.0
Please use "brainpy.optim.SGD" instead.
"""
warnings.warn('Please use "brainpy.optim.SGD" instead. '
'"brainpy.math.optimizers.SGD" is '
'deprecated since version 2.0.3. ',
@@ -30,6 +35,11 @@ def SGD(*args, **kwargs):


def Momentum(*args, **kwargs):
"""Momentum optimizer.

.. deprecated:: 2.1.0
Please use "brainpy.optim.Momentum" instead.
"""
warnings.warn('Please use "brainpy.optim.Momentum" instead. '
'"brainpy.math.optimizers.Momentum" is '
'deprecated since version 2.0.3. ',
@@ -38,6 +48,11 @@ def Momentum(*args, **kwargs):


def MomentumNesterov(*args, **kwargs):
"""MomentumNesterov optimizer.

.. deprecated:: 2.1.0
Please use "brainpy.optim.MomentumNesterov" instead.
"""
warnings.warn('Please use "brainpy.optim.MomentumNesterov" instead. '
'"brainpy.math.optimizers.MomentumNesterov" is '
'deprecated since version 2.0.3. ',
@@ -46,6 +61,11 @@ def MomentumNesterov(*args, **kwargs):


def Adagrad(*args, **kwargs):
"""Adagrad optimizer.

.. deprecated:: 2.1.0
Please use "brainpy.optim.Adagrad" instead.
"""
warnings.warn('Please use "brainpy.optim.Adagrad" instead. '
'"brainpy.math.optimizers.Adagrad" is '
'deprecated since version 2.0.3. ',
@@ -54,6 +74,11 @@ def Adagrad(*args, **kwargs):


def Adadelta(*args, **kwargs):
"""Adadelta optimizer.

.. deprecated:: 2.1.0
Please use "brainpy.optim.Adadelta" instead.
"""
warnings.warn('Please use "brainpy.optim.Adadelta" instead. '
'"brainpy.math.optimizers.Adadelta" is '
'deprecated since version 2.0.3. ',
@@ -62,6 +87,11 @@ def Adadelta(*args, **kwargs):


def RMSProp(*args, **kwargs):
"""RMSProp optimizer.

.. deprecated:: 2.1.0
Please use "brainpy.optim.RMSProp" instead.
"""
warnings.warn('Please use "brainpy.optim.RMSProp" instead. '
'"brainpy.math.optimizers.RMSProp" is '
'deprecated since version 2.0.3. ',
@@ -70,6 +100,11 @@ def RMSProp(*args, **kwargs):


def Adam(*args, **kwargs):
"""Adam optimizer.

.. deprecated:: 2.1.0
Please use "brainpy.optim.Adam" instead.
"""
warnings.warn('Please use "brainpy.optim.Adam" instead. '
'"brainpy.math.optimizers.Adam" is '
'deprecated since version 2.0.3. ',
@@ -78,6 +113,11 @@ def Adam(*args, **kwargs):


def Constant(*args, **kwargs):
"""Constant scheduler.

.. deprecated:: 2.1.0
Please use "brainpy.optim.Constant" instead.
"""
warnings.warn('Please use "brainpy.optim.Constant" instead. '
'"brainpy.math.optimizers.Constant" is '
'deprecated since version 2.0.3. ',
@@ -86,6 +126,11 @@ def Constant(*args, **kwargs):


def ExponentialDecay(*args, **kwargs):
"""ExponentialDecay scheduler.

.. deprecated:: 2.1.0
Please use "brainpy.optim.ExponentialDecay" instead.
"""
warnings.warn('Please use "brainpy.optim.ExponentialDecay" instead. '
'"brainpy.math.optimizers.ExponentialDecay" is '
'deprecated since version 2.0.3. ',
@@ -94,6 +139,11 @@ def ExponentialDecay(*args, **kwargs):


def InverseTimeDecay(*args, **kwargs):
"""InverseTimeDecay scheduler.

.. deprecated:: 2.1.0
Please use "brainpy.optim.InverseTimeDecay" instead.
"""
warnings.warn('Please use "brainpy.optim.InverseTimeDecay" instead. '
'"brainpy.math.optimizers.InverseTimeDecay" is '
'deprecated since version 2.0.3. ',
@@ -102,6 +152,11 @@ def InverseTimeDecay(*args, **kwargs):


def PolynomialDecay(*args, **kwargs):
"""PolynomialDecay scheduler.

.. deprecated:: 2.1.0
Please use "brainpy.optim.PolynomialDecay" instead.
"""
warnings.warn('Please use "brainpy.optim.PolynomialDecay" instead. '
'"brainpy.math.optimizers.PolynomialDecay" is '
'deprecated since version 2.0.3. ',
@@ -110,6 +165,11 @@ def PolynomialDecay(*args, **kwargs):


def PiecewiseConstant(*args, **kwargs):
"""PiecewiseConstant scheduler.

.. deprecated:: 2.1.0
Please use "brainpy.optim.PiecewiseConstant" instead.
"""
warnings.warn('Please use "brainpy.optim.PiecewiseConstant" instead. '
'"brainpy.math.optimizers.PiecewiseConstant" is '
'deprecated since version 2.0.3. ',

+ 183
- 107
brainpy/math/delay_vars.py View File

@@ -1,41 +1,47 @@
# -*- coding: utf-8 -*-


from typing import Union, Callable, Tuple
from typing import Union, Callable

import jax.numpy as jnp
import numpy as np
from jax import vmap
from jax.experimental.host_callback import id_tap
from jax.lax import cond

from brainpy import math as bm
from brainpy import check
from brainpy.base.base import Base
from brainpy.tools.checking import check_float
from brainpy.tools.others import to_size
from brainpy.errors import UnsupportedError
from brainpy.math import numpy_ops as ops
from brainpy.math.jaxarray import ndarray, Variable
from brainpy.math.setting import get_dt
from brainpy.tools.checking import check_float, check_integer

__all__ = [
'AbstractDelay',
'FixedLenDelay',
'TimeDelay',
'NeutralDelay',
'LengthDelay',
]


class AbstractDelay(Base):
def update(self, time, value):
def update(self, *args, **kwargs):
raise NotImplementedError


_FUNC_BEFORE = 'function'
_DATA_BEFORE = 'data'
_INTERP_LINEAR = 'linear_interp'
_INTERP_ROUND = 'round'


class FixedLenDelay(AbstractDelay):
"""Delay variable which has a fixed delay length.
class TimeDelay(AbstractDelay):
"""Delay variable which has a fixed delay time length.

For example, we create a delay variable which has a maximum delay length of 1 ms

>>> import brainpy.math as bm
>>> delay = bm.FixedLenDelay(bm.zeros(3), delay_len=1., dt=0.1)
>>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1)
>>> delay(-0.5)
[-0. -0. -0.]

@@ -43,13 +49,13 @@ class FixedLenDelay(AbstractDelay):

1. the one-dimensional delay data

>>> delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t)
>>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t)
>>> delay(-0.2)
[-0.2 -0.2 -0.2]

2. the two-dimensional delay data

>>> delay = bm.FixedLenDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t)
>>> delay = bm.TimeDelay(bm.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t)
>>> delay(-0.6)
[[-0.6 -0.6]
[-0.6 -0.6]
@@ -57,8 +63,8 @@ class FixedLenDelay(AbstractDelay):

3. the three-dimensional delay data

>>> delay = bm.FixedLenDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t)
>>> delay(-0.6)
>>> delay = bm.TimeDelay(bm.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t)
>>> delay(-0.8)
[[[-0.8]
[-0.8]]
[[-0.8]
@@ -68,8 +74,8 @@ class FixedLenDelay(AbstractDelay):

Parameters
----------
shape: int, sequence of int
The delay data shape.
inits: int, sequence of int
The initial delay data.
t0: float, int
The zero time.
delay_len: float, int
@@ -83,155 +89,225 @@ class FixedLenDelay(AbstractDelay):
of :math:`(num_delay, ...)`, where the longest delay data is aranged in
the first index.
name: str
The delay instance name.
interp_method: str
The way to deal with the delay at the time which is not integer times of the time step.
For exameple, if the time step ``dt=0.1``, the time delay length ``delay_len=1.``,
when users require the delay data at ``t-0.53``, we can deal this situation with
the following methods:

- ``"linear_interp"``: using linear interpolation to get the delay value
at the required time (default).
- ``"round"``: round the time to make it is the integer times of the time step. For
the above situation, we will use the time at ``t-0.5`` to approximate the delay data
at ``t-0.53``.

.. versionadded:: 2.1.1

See Also
--------
LengthDelay
"""

def __init__(
self,
shape: Union[int, Tuple[int, ...]],
inits: Union[ndarray, jnp.ndarray],
delay_len: Union[float, int],
before_t0: Union[Callable, bm.ndarray, jnp.ndarray, float, int] = None,
before_t0: Union[Callable, ndarray, jnp.ndarray, float, int] = None,
t0: Union[float, int] = 0.,
dt: Union[float, int] = None,
name: str = None,
dtype=None,
interp_method='linear_interp',
):
super(FixedLenDelay, self).__init__(name=name)
super(TimeDelay, self).__init__(name=name)

# shape
self.shape = to_size(shape)
self.dtype = dtype
assert isinstance(inits, (ndarray, np.ndarray)), (f'Must be an instance of brainpy.math.ndarray '
f'or jax.numpy.ndarray. But we got {type(inits)}')
self.shape = inits.shape

# delay_len
self.t0 = t0
self._dt = bm.get_dt() if dt is None else dt
self.dt = get_dt() if dt is None else dt
check_float(delay_len, 'delay_len', allow_none=False, allow_int=True, min_bound=0.)
self._delay_len = delay_len
self.delay_len = delay_len + self._dt
self.num_delay_steps = int(bm.ceil(self.delay_len / self._dt).value)
self.delay_len = delay_len
self.num_delay_step = int(ops.ceil(self.delay_len / self.dt).value) + 1

# interp method
if interp_method not in [_INTERP_LINEAR, _INTERP_ROUND]:
raise UnsupportedError(f'Un-supported interpolation method {interp_method}, '
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
self.interp_method = interp_method

# other variables
self._idx = bm.Variable(bm.asarray([0]))
# time variables
self.idx = ops.Variable(ops.asarray([0]))
check_float(t0, 't0', allow_none=False, allow_int=True, )
self._current_time = bm.Variable(bm.asarray([t0]))
self.current_time = Variable(ops.asarray([t0]))

# delay data
self._data = bm.Variable(bm.zeros((self.num_delay_steps,) + self.shape, dtype=dtype))
self.data = Variable(ops.zeros((self.num_delay_step,) + self.shape,
dtype=inits.dtype))
if before_t0 is None:
self._before_type = _DATA_BEFORE
elif callable(before_t0):
self._before_t0 = lambda t: jnp.asarray(bm.broadcast_to(before_t0(t), self.shape).value,
dtype=self.dtype)
self._before_t0 = lambda t: jnp.asarray(ops.broadcast_to(before_t0(t), self.shape).value,
dtype=inits.dtype)
self._before_type = _FUNC_BEFORE
elif isinstance(before_t0, (bm.ndarray, jnp.ndarray, float, int)):
elif isinstance(before_t0, (ndarray, jnp.ndarray, float, int)):
self._before_type = _DATA_BEFORE
try:
self._data[:] = before_t0
except:
raise ValueError(f'Cannot set delay data by using "before_t0". '
f'The delay data has the shape of '
f'{((self.num_delay_steps,) + self.shape)}, while '
f'we got "before_t0" of {bm.asarray(before_t0).shape}. '
f'They are not compatible. Note that the delay length '
f'{self._delay_len} will automatically add a dt {self.dt} '
f'to {self.delay_len}.')
self.data[:-1] = before_t0
else:
raise ValueError(f'"before_t0" does not support {type(before_t0)}: before_t0')

@property
def idx(self):
return self._idx

@idx.setter
def idx(self, value):
raise ValueError('Cannot set "idx" by users.')

@property
def dt(self):
return self._dt

@dt.setter
def dt(self, value):
raise ValueError('Cannot set "dt" by users.')

@property
def data(self):
return self._data

@data.setter
def data(self, value):
self._data[:] = value
raise ValueError(f'"before_t0" does not support {type(before_t0)}')
# set initial data
self.data[-1] = inits

@property
def current_time(self):
return self._current_time[0]
# interpolation function
self.f = jnp.interp
for dim in range(1, len(self.shape) + 1, 1):
self.f = vmap(self.f, in_axes=(None, None, dim), out_axes=dim - 1)

def _check_time(self, times, transforms):
prev_time, current_time = times
current_time = bm.as_device_array(current_time)
prev_time = bm.as_device_array(prev_time)
if prev_time > current_time:
current_time = current_time[0]
if prev_time > current_time + 1e-6:
raise ValueError(f'\n'
f'!!! Error in {self.__class__.__name__}: \n'
f'The request time should be less than the '
f'current time {current_time}. But we '
f'got {prev_time} > {current_time}')
lower_time = jnp.asarray(current_time - self.delay_len)
if prev_time < lower_time:
lower_time = current_time - self.delay_len
if prev_time < lower_time - self.dt:
raise ValueError(f'\n'
f'!!! Error in {self.__class__.__name__}: \n'
f'The request time of the variable should be in '
f'[{lower_time}, {current_time}], but we got {prev_time}')

def __call__(self, prev_time):
def __call__(self, time, indices=None):
# check
id_tap(self._check_time, (prev_time, self.current_time))
if check.is_checking():
id_tap(self._check_time, (time, self.current_time))
if self._before_type == _FUNC_BEFORE:
return cond(prev_time < self.t0,
return cond(time < self.t0,
self._before_t0,
self._fn1,
prev_time)
self._after_t0,
time)
else:
return self._fn1(prev_time)

def _fn1(self, prev_time):
diff = self.delay_len - (self.current_time - prev_time)
if isinstance(diff, bm.ndarray): diff = diff.value
req_num_step = jnp.asarray(diff / self._dt, dtype=bm.get_dint())
extra = diff - req_num_step * self._dt
return cond(extra == 0., self._true_fn, self._false_fn, (req_num_step, extra))
return self._after_t0(time)

def _after_t0(self, prev_time):
diff = self.delay_len - (self.current_time[0] - prev_time)
if isinstance(diff, ndarray):
diff = diff.value
if self.interp_method == _INTERP_LINEAR:
req_num_step = jnp.asarray(diff / self.dt, dtype=ops.int32)
extra = diff - req_num_step * self.dt
return cond(extra == 0., self._true_fn, self._false_fn, (req_num_step, extra))
elif self.interp_method == _INTERP_ROUND:
req_num_step = jnp.asarray(jnp.round(diff / self.dt), dtype=ops.int32)
return self._true_fn([req_num_step, 0.])
else:
raise UnsupportedError(f'Un-supported interpolation method {self.interp_method}, '
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')

def _true_fn(self, div_mod):
req_num_step, extra = div_mod
return self._data[self.idx[0] + req_num_step]
return self.data[self.idx[0] + req_num_step]

def _false_fn(self, div_mod):
req_num_step, extra = div_mod
f = jnp.interp
for dim in range(1, len(self.shape) + 1, 1):
f = vmap(f, in_axes=(None, None, dim), out_axes=dim - 1)
idx = jnp.asarray([self.idx[0] + req_num_step,
self.idx[0] + req_num_step + 1])
idx %= self.num_delay_steps
return f(extra, jnp.asarray([0., self._dt]), self._data[idx])
idx %= self.num_delay_step
return self.f(extra, jnp.asarray([0., self.dt]), self.data[idx])

def update(self, time, value):
self._data[self._idx[0]] = value
# check_float(time, 'time', allow_none=False, allow_int=True)
self._current_time[0] = time
self._idx.value = (self._idx + 1) % self.num_delay_steps
self.data[self.idx[0]] = value
self.current_time[0] = time
self.idx.value = (self.idx + 1) % self.num_delay_step


class NeutralDelay(TimeDelay):
pass


class VariedLenDelay(AbstractDelay):
"""Delay variable which has a functional delay
class LengthDelay(AbstractDelay):
"""Delay variable which has a fixed delay length.

Parameters
----------
inits: int, sequence of int
The initial delay data.
delay_len: int
The maximum delay length.
delay_data: Tensor
The delay data.
name: str
The delay object name.

See Also
--------
TimeDelay
"""

def update(self, time, value):
pass
def __init__(
self,
inits: Union[ndarray, jnp.ndarray],
delay_len: int,
delay_data: Union[ndarray, jnp.ndarray, float, int] = None,
name: str = None,
):
super(LengthDelay, self).__init__(name=name)
self.init(inits, delay_len, delay_data)

def __init__(self):
super(VariedLenDelay, self).__init__()
def init(self, inits, delay_len, delay_data):
assert isinstance(inits, (ndarray, np.ndarray)), (f'Must be an instance of brainpy.math.ndarray '
f'or jax.numpy.ndarray. But we got {type(inits)}')
self.shape = inits.shape

# delay_len
check_integer(delay_len, 'delay_len', allow_none=False, min_bound=0)
self.delay_len = delay_len
self.num_delay_step = delay_len + 1

class NeutralDelay(FixedLenDelay):
pass
# time variables
self.idx = Variable(ops.asarray([0], dtype=ops.int32))

# delay data
self.data = Variable(ops.zeros((self.num_delay_step,) + self.shape,
dtype=inits.dtype))
if delay_data is None:
pass
elif isinstance(delay_data, (ndarray, jnp.ndarray, float, int)):
self.data[:-1] = delay_data
else:
raise ValueError(f'"delay_data" does not support {type(delay_data)}')

def _check_delay(self, delay_len, transforms):
if isinstance(delay_len, ndarray):
delay_len = delay_len.value
if np.any(delay_len >= self.num_delay_step):
raise ValueError(f'\n'
f'!!! Error in {self.__class__.__name__}: \n'
f'The request delay length should be less than the '
f'maximum delay {self.delay_len}. But we '
f'got {delay_len}')

def __call__(self, delay_len, indices=None):
# check
if check.is_checking():
id_tap(self._check_delay, delay_len)
# the delay length
delay_idx = (self.idx[0] - delay_len - 1) % self.num_delay_step
if delay_idx.dtype not in [ops.int32, ops.int64]:
raise ValueError(f'"delay_len" must be integer, but we got {delay_len}')
# the delay data
if indices is None:
return self.data[delay_idx]
else:
return self.data[delay_idx, indices]

def update(self, value):
if ops.shape(value) != self.shape:
raise ValueError(f'value shape should be {self.shape}, but we got {ops.shape(value)}')
self.data[self.idx[0]] = value
self.idx.value = (self.idx + 1) % self.num_delay_step

+ 8
- 8
brainpy/math/numpy_ops.py View File

@@ -79,7 +79,7 @@ __all__ = [
'setxor1d', 'tensordot', 'trim_zeros', 'union1d', 'unravel_index', 'unwrap', 'take_along_axis',

# others
'clip_by_norm', 'as_device_array', 'as_variable', 'as_jaxarray', 'as_numpy',
'clip_by_norm', 'as_device_array', 'as_variable', 'as_numpy',
]

_min = min
@@ -89,6 +89,10 @@ _max = max
# others
# ------

# def as_jax_array(tensor):
# return asarray(tensor)


def as_device_array(tensor):
if isinstance(tensor, JaxArray):
return tensor.value
@@ -111,10 +115,6 @@ def as_variable(tensor):
return Variable(asarray(tensor))


def as_jaxarray(tensor):
return asarray(tensor)


def _remove_jaxarray(obj):
if isinstance(obj, JaxArray):
return obj.value
@@ -1507,10 +1507,10 @@ def vander(x, N=None, increasing=False):


def fill_diagonal(a, val):
a = _remove_jaxarray(a)
assert a.ndim >= 2
assert isinstance(a, JaxArray), f'Must be a JaxArray, but got {type(a)}'
assert a.ndim >= 2, f'Only support tensor has dimension >= 2, but got {a.shape}'
i, j = jnp.diag_indices(_min(a.shape[-2:]))
return JaxArray(a.at[..., i, j].set(val))
a._value = a.value.at[..., i, j].set(val)


# indexing funcs


+ 32
- 24
brainpy/math/parallels.py View File

@@ -27,6 +27,7 @@ from brainpy import errors
from brainpy.base.base import Base
from brainpy.base.collector import TensorCollector
from brainpy.math.random import RandomState
from brainpy.math.jaxarray import JaxArray
from brainpy.tools.codes import change_func_name

__all__ = [
@@ -35,29 +36,31 @@ __all__ = [
]


def _make_vmap(func, dyn_vars, rand_vars, in_axes, out_axes,
batch_idx, axis_name, reduce_func, f_name=None):
def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes,
batch_idx, axis_name, f_name=None):
@functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name)
def vmapped_func(dyn_data, rand_data, *args, **kwargs):
dyn_vars.assign(dyn_data)
rand_vars.assign(rand_data)
def vmapped_func(nonbatched_data, batched_data, *args, **kwargs):
nonbatched_vars.assign(nonbatched_data)
batched_vars.assign(batched_data)
out = func(*args, **kwargs)
dyn_changes = dyn_vars.dict()
rand_changes = rand_vars.dict()
return out, dyn_changes, rand_changes
nonbatched_changes = nonbatched_vars.dict()
batched_changes = batched_vars.dict()
return nonbatched_changes, batched_changes, out

def call(*args, **kwargs):
dyn_data = dyn_vars.dict()
n = args[batch_idx[0]].shape[batch_idx[1]]
rand_data = {key: val.split_keys(n) for key, val in rand_vars.items()}
nonbatched_data = nonbatched_vars.dict()
batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()}
try:
out, dyn_changes, rand_changes = vmapped_func(dyn_data, rand_data, *args, **kwargs)
out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs)
except UnexpectedTracerError as e:
dyn_vars.assign(dyn_data)
rand_vars.assign(rand_data)
raise errors.JaxTracerError(variables=dyn_vars) from e
for key, v in dyn_changes.items(): dyn_vars[key] = reduce_func(v)
for key, v in rand_changes.items(): rand_vars[key] = reduce_func(v)
nonbatched_vars.assign(nonbatched_data)
batched_vars.assign(batched_data)
raise errors.JaxTracerError() from e
# for key, v in dyn_changes.items():
# dyn_vars[key] = reduce_func(v)
# for key, v in rand_changes.items():
# rand_vars[key] = reduce_func(v)
return out

return change_func_name(name=f_name, f=call) if f_name else call
@@ -77,7 +80,7 @@ def vmap(func, dyn_vars=None, batched_vars=None,
----------
func : Base, function, callable
The function or the module to compile.
dyn_vars : dict
dyn_vars : dict, sequence
batched_vars : dict
in_axes : optional, int, sequence of int
Specify which input array axes to map over. If each positional argument to
@@ -207,13 +210,19 @@ def vmap(func, dyn_vars=None, batched_vars=None,
axis_name=axis_name)

else:
if isinstance(dyn_vars, JaxArray):
dyn_vars = [dyn_vars]
if isinstance(dyn_vars, (tuple, list)):
dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)}
assert isinstance(dyn_vars, dict)

# dynamical variables
dyn_vars, rand_vars = TensorCollector(), TensorCollector()
_dyn_vars, _rand_vars = TensorCollector(), TensorCollector()
for key, val in dyn_vars.items():
if isinstance(val, RandomState):
rand_vars[key] = val
_rand_vars[key] = val
else:
dyn_vars[key] = val
_dyn_vars[key] = val

# in axes
if in_axes is None:
@@ -249,13 +258,12 @@ def vmap(func, dyn_vars=None, batched_vars=None,

# jit function
return _make_vmap(func=func,
dyn_vars=dyn_vars,
rand_vars=rand_vars,
nonbatched_vars=_dyn_vars,
batched_vars=_rand_vars,
in_axes=in_axes,
out_axes=out_axes,
axis_name=axis_name,
batch_idx=batch_idx,
reduce_func=reduce_func)
batch_idx=batch_idx)

else:
raise errors.BrainPyError(f'Only support instance of {Base.__name__}, or a callable '


+ 0
- 71
brainpy/math/special.py View File

@@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp

from brainpy.tools.checking import check_integer

__all__ = [
'poch',
'Gamma',
'Beta',
]


def poch(a, n):
""" Returns the Pochhammer symbol (a)_n. """
# First, check if 'a' is a real number (this is currently only working for reals).
assert not isinstance(a, complex), "a must be real: %r" % a
check_integer(n, allow_none=False, min_bound=0)
# Compute the Pochhammer symbol.
return 1.0 if n == 0 else jnp.prod(jnp.arange(n) + a)


def Gamma(z):
""" Paul Godfrey's Gamma function implementation valid for z complex.
This is converted from Godfrey's Gamma.m Matlab file available at
https://www.mathworks.com/matlabcentral/fileexchange/3572-gamma.
15 significant digits of accuracy for real z and 13
significant digits for other values.
"""
zz = z

# Find negative real parts of z and make them positive.
if isinstance(z, (complex, jnp.complex64, jnp.complex128)):
Z = [z.real, z.imag]
if Z[0] < 0:
Z[0] = -Z[0]
z = jnp.asarray(Z)
z = z.astype(complex)

g = 607 / 128.
c = jnp.asarray([0.99999999999999709182, 57.156235665862923517, -59.597960355475491248,
14.136097974741747174, -0.49191381609762019978, .33994649984811888699e-4,
.46523628927048575665e-4, -.98374475304879564677e-4, .15808870322491248884e-3,
-.21026444172410488319e-3, .21743961811521264320e-3, -.16431810653676389022e-3,
.84418223983852743293e-4, -.26190838401581408670e-4, .36899182659531622704e-5])
if z == 0 or z == 1:
return 1.
if ((jnp.round(zz) == zz)
and (zz.imag == 0)
and (zz.real <= 0)): # Adjust for negative poles.
return jnp.inf
z = z - 1
zh = z + 0.5
zgh = zh + g
zp = zgh ** (zh * 0.5) # Trick for avoiding floating-point overflow above z = 141.
idx = jnp.arange(len(c) - 1, 0, -1)
ss = jnp.sum(c[idx] / (z + idx))
sq2pi = 2.5066282746310005024157652848110
f = (sq2pi * (c[0] + ss)) * ((zp * jnp.exp(-zgh)) * zp)
if isinstance(zz, (complex, jnp.complex64, jnp.complex128)):
return f.astype(complex)
elif isinstance(zz, int) and zz >= 0:
f = jnp.round(f)
return f.astype(int)
else:
return f


def Beta(x, y):
""" Beta function using Godfrey's Gamma function. """
return Gamma(x) * Gamma(y) / Gamma(x + y)

+ 78
- 23
brainpy/math/tests/test_delay_vars.py View File

@@ -5,50 +5,66 @@ import unittest
import brainpy.math as bm


class TestFixedLenDelay(unittest.TestCase):
class TestTimeDelay(unittest.TestCase):
def test_dim1(self):
bm.enable_x64()

# linear interp
t0 = 0.
before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1)
delay = bm.FixedLenDelay(10, delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones(10) * 10))
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones(10) * 9.5))
before_t0 = bm.repeat(bm.arange(10).reshape((-1, 1)), 10, axis=1)
delay = bm.TimeDelay(bm.zeros(10), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
print(delay(t0 - 0.1))
print(delay(t0 - 0.15))
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones(10) * 9.))
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones(10) * 8.5))
print()
print(delay(t0 - 0.23))
print(delay(t0 - 0.23) - bm.ones(10) * 8.7)
# self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones(10) * 8.7))

# round interp
delay = bm.TimeDelay(bm.zeros(10), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0,
interp_method='round')
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones(10) * 9))
print(delay(t0 - 0.15))
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones(10) * 8))
self.assertTrue(bm.array_equal(delay(t0 - 0.2), bm.ones(10) * 8))

def test_dim2(self):
t0 = 0.
before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1)
before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2)
delay = bm.FixedLenDelay((10, 5), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5)) * 10))
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5)) * 9.5))
before_t0 = bm.repeat(bm.arange(10).reshape((-1, 1)), 10, axis=1)
before_t0 = bm.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2)
delay = bm.TimeDelay(bm.zeros((10, 5)), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5)) * 9))
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5)) * 8.5))
# self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones((10, 5)) * 8.7))

def test_dim3(self):
t0 = 0.
before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1)
before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2)
before_t0 = bm.repeat(before_t0.reshape((11, 10, 5, 1)), 3, axis=3)
delay = bm.FixedLenDelay((10, 5, 3), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5, 3)) * 10))
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5, 3)) * 9.5))
before_t0 = bm.repeat(bm.arange(10).reshape((-1, 1)), 10, axis=1)
before_t0 = bm.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2)
before_t0 = bm.repeat(before_t0.reshape((10, 10, 5, 1)), 3, axis=3)
delay = bm.TimeDelay(bm.zeros((10, 5, 3)), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5, 3)) * 9))
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5, 3)) * 8.5))
# self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones((10, 5, 3)) * 8.7))

def test1(self):
print()
delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t)
delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t)
print(delay(-0.2))
delay = bm.FixedLenDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t)
delay = bm.TimeDelay(bm.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t)
print(delay(-0.6))
delay = bm.FixedLenDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t)
delay = bm.TimeDelay(bm.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t)
print(delay(-0.8))

def test_current_time2(self):
print()
delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t)
delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t)
print(delay(0.))
before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1)
before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2)
delay = bm.FixedLenDelay((10, 5), delay_len=1., dt=0.1, before_t0=before_t0)
before_t0 = bm.repeat(bm.arange(10).reshape((-1, 1)), 10, axis=1)
before_t0 = bm.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2)
delay = bm.TimeDelay(bm.zeros((10, 5)), delay_len=1., dt=0.1, before_t0=before_t0)
print(delay(0.))

# def test_prev_time_beyond_boundary(self):
@@ -56,3 +72,42 @@ class TestFixedLenDelay(unittest.TestCase):
# delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t)
# delay(-1.2)


class TestLengthDelay(unittest.TestCase):
def test1(self):
dim = 3
delay = bm.LengthDelay(bm.zeros(dim), 10)
print(delay(1))
self.assertTrue(bm.array_equal(delay(1), bm.zeros(dim)))

delay = bm.jit(delay)
print(delay(1))
self.assertTrue(bm.array_equal(delay(1), bm.zeros(dim)))

def test2(self):
dim = 3
delay = bm.LengthDelay(bm.zeros(dim), 10, delay_data=bm.arange(1, 11).reshape((10, 1)))
print(delay(0))
self.assertTrue(bm.array_equal(delay(0), bm.zeros(dim)))
print(delay(1))
self.assertTrue(bm.array_equal(delay(1), bm.ones(dim) * 10))

delay = bm.jit(delay)
print(delay(0))
self.assertTrue(bm.array_equal(delay(0), bm.zeros(dim)))
print(delay(1))
self.assertTrue(bm.array_equal(delay(1), bm.ones(dim) * 10))

def test3(self):
dim = 3
delay = bm.LengthDelay(bm.zeros(dim), 10, delay_data=bm.arange(1, 11).reshape((10, 1)))
print(delay(bm.asarray([1, 2, 3]),
bm.arange(3)))
# self.assertTrue(bm.array_equal(delay(0), bm.zeros(dim)))

delay = bm.jit(delay)
print(delay(bm.asarray([1, 2, 3]),
bm.arange(3)))
# self.assertTrue(bm.array_equal(delay(1), bm.ones(dim) * 10))



+ 2
- 203
brainpy/measure/__init__.py View File

@@ -5,207 +5,6 @@ This module aims to provide commonly used analysis methods for simulated neurona
You can access them through ``brainpy.measure.XXX``.
"""

from .correlation import *
from .firings import *

import numpy as np

from brainpy import tools, math

__all__ = [
'cross_correlation',
'voltage_fluctuation',
'raster_plot',
'firing_rate',
]


# @tools.numba_jit
def _cc(states, i, j):
sqrt_ij = np.sqrt(np.sum(states[i]) * np.sum(states[j]))
k = 0. if sqrt_ij == 0. else np.sum(states[i] * states[j]) / sqrt_ij
return k


def cross_correlation(spikes, bin, dt=None):
r"""Calculate cross correlation index between neurons.

The coherence [1]_ between two neurons i and j is measured by their
cross-correlation of spike trains at zero time lag within a time bin
of :math:`\Delta t = \tau`. More specifically, suppose that a long
time interval T is divided into small bins of :math:`\Delta t` and
that two spike trains are given by :math:`X(l)=` 0 or 1, :math:`Y(l)=` 0
or 1, :math:`l=1,2, \ldots, K(T / K=\tau)`. Thus, we define a coherence
measure for the pair as:

.. math::

\kappa_{i j}(\tau)=\frac{\sum_{l=1}^{K} X(l) Y(l)}
{\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}}

The population coherence measure :math:`\kappa(\tau)` is defined by the
average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the
network.

Parameters
----------
spikes :
The history of spike states of the neuron group.
It can be easily get via `StateMonitor(neu, ['spike'])`.
bin : float, int
The time bin to normalize spike states.
dt : float, optional
The time precision.

Returns
-------
cc_index : float
The cross correlation value which represents the synchronization index.

References
----------
.. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic
inhibition in a hippocampal interneuronal network model." Journal of
neuroscience 16.20 (1996): 6402-6413.
"""
spikes = np.asarray(spikes)
dt = math.get_dt() if dt is None else dt
bin_size = int(bin / dt)
num_hist, num_neu = spikes.shape
num_bin = int(np.ceil(num_hist / bin_size))
if num_bin * bin_size != num_hist:
spikes = np.append(spikes, np.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0)
states = spikes.T.reshape((num_neu, num_bin, bin_size))
states = (np.sum(states, axis=2) > 0.).astype(np.float_)
all_k = []
for i in range(num_neu):
for j in range(i + 1, num_neu):
all_k.append(_cc(states, i, j))
return np.mean(all_k)


# @tools.numba_jit
def _var(neu_signal):
return np.mean(neu_signal * neu_signal) - np.mean(neu_signal) ** 2


def voltage_fluctuation(potentials):
r"""Calculate neuronal synchronization via voltage variance.

The method comes from [1]_ [2]_ [3]_.

First, average over the membrane potential :math:`V`

.. math::

V(t) = \frac{1}{N} \sum_{i=1}^{N} V_i(t)

The variance of the time fluctuations of :math:`V(t)` is

.. math::

\sigma_V^2 = \left\langle \left[ V(t) \right]^2 \right\rangle_t -
\left[ \left\langle V(t) \right\rangle_t \right]^2

where :math:`\left\langle \ldots \right\rangle_t = (1 / T_m) \int_0^{T_m} dt \, \ldots`
denotes time-averaging over a large time, :math:`\tau_m`. After normalization
of :math:`\sigma_V` to the average over the population of the single cell
membrane potentials

.. math::

\sigma_{V_i}^2 = \left\langle\left[ V_i(t) \right]^2 \right\rangle_t -
\left[ \left\langle V_i(t) \right\rangle_t \right]^2

one defines a synchrony measure, :math:`\chi (N)`, for the activity of a system
of :math:`N` neurons by:

.. math::

\chi^2 \left( N \right) = \frac{\sigma_V^2}{ \frac{1}{N} \sum_{i=1}^N
\sigma_{V_i}^2}

Parameters
----------
potentials :
The membrane potential matrix of the neuron group.

Returns
-------
sync_index : float
The synchronization index.

References
----------
.. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled
inhibitory neurons with heterogeneity. Phys. Rev. reversal_potential 48:4810-4814.
.. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled
inhibitory neurons. Physica D 72:259-282.
.. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347.
"""

potentials = np.asarray(potentials)
num_hist, num_neu = potentials.shape
avg = np.mean(potentials, axis=1)
avg_var = np.mean(avg * avg) - np.mean(avg) ** 2
neu_vars = []
for i in range(num_neu):
neu_vars.append(_var(potentials[:, i]))
var_mean = np.mean(neu_vars)
return avg_var / var_mean if var_mean != 0. else 1.


def raster_plot(sp_matrix, times):
"""Get spike raster plot which displays the spiking activity
of a group of neurons over time.

Parameters
----------
sp_matrix : bnp.ndarray
The matrix which record spiking activities.
times : bnp.ndarray
The time steps.

Returns
-------
raster_plot : tuple
Include (neuron index, spike time).
"""
sp_matrix = np.asarray(sp_matrix)
times = np.asarray(times)
elements = np.where(sp_matrix > 0.)
index = elements[1]
time = times[elements[0]]
return index, time


def firing_rate(sp_matrix, width, dt=None):
r"""Calculate the mean firing rate over in a neuron group.

This method is adopted from Brian2.

The firing rate in trial :math:`k` is the spike count :math:`n_{k}^{sp}`
in an interval of duration :math:`T` divided by :math:`T`:

.. math::

v_k = {n_k^{sp} \over T}

Parameters
----------
sp_matrix : math.JaxArray, np.ndarray
The spike matrix which record spiking activities.
width : int, float
The width of the ``window`` in millisecond.
dt : float, optional
The sample rate.

Returns
-------
rate : numpy.ndarray
The population rate in Hz, smoothed with the given window.
"""
sp_matrix = np.asarray(sp_matrix)
rate = np.sum(sp_matrix, axis=1) / sp_matrix.shape[1]
dt = math.get_dt() if dt is None else dt
width1 = int(width / 2 / dt) * 2 + 1
window = np.ones(width1) * 1000 / width
return np.convolve(rate, window, mode='same')

+ 270
- 0
brainpy/measure/correlation.py View File

@@ -0,0 +1,270 @@
# -*- coding: utf-8 -*-

from functools import partial

import numpy as np
from jax import vmap, jit, lax, numpy as jnp

from brainpy import math as bm

__all__ = [
'cross_correlation',
'voltage_fluctuation',
'matrix_correlation',
'weighted_correlation',
'functional_connectivity',
'functional_connectivity_dynamics',
]


@jit
@partial(vmap, in_axes=(None, 0, 0))
def _cc(states, i, j):
sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
return lax.cond(sqrt_ij == 0.,
lambda _: 0.,
lambda ij: jnp.sum(states[i] * states[j]) / sqrt_ij,
(i, j))


def cross_correlation(spikes, bin, dt=None):
r"""Calculate cross correlation index between neurons.

The coherence [1]_ between two neurons i and j is measured by their
cross-correlation of spike trains at zero time lag within a time bin
of :math:`\Delta t = \tau`. More specifically, suppose that a long
time interval T is divided into small bins of :math:`\Delta t` and
that two spike trains are given by :math:`X(l)=` 0 or 1, :math:`Y(l)=` 0
or 1, :math:`l=1,2, \ldots, K(T / K=\tau)`. Thus, we define a coherence
measure for the pair as:

.. math::

\kappa_{i j}(\tau)=\frac{\sum_{l=1}^{K} X(l) Y(l)}
{\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}}

The population coherence measure :math:`\kappa(\tau)` is defined by the
average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the
network.

Parameters
----------
spikes :
The history of spike states of the neuron group.
It can be easily get via `StateMonitor(neu, ['spike'])`.
bin : float, int
The time bin to normalize spike states.
dt : float, optional
The time precision.

Returns
-------
cc_index : float
The cross correlation value which represents the synchronization index.

References
----------
.. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic
inhibition in a hippocampal interneuronal network model." Journal of
neuroscience 16.20 (1996): 6402-6413.
"""
spikes = bm.asarray(spikes)
dt = bm.get_dt() if dt is None else dt
bin_size = int(bin / dt)
num_hist, num_neu = spikes.shape
num_bin = int(np.ceil(num_hist / bin_size))
if num_bin * bin_size != num_hist:
spikes = bm.append(spikes, bm.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0)
states = spikes.T.reshape((num_neu, num_bin, bin_size))
states = bm.asarray(bm.sum(states, axis=2) > 0., dtype=jnp.float_)
indices = jnp.tril_indices(4, k=-1)
return jnp.mean(_cc(states.value, *indices))


@partial(vmap, in_axes=(None, 0))
def _var(neu_signal, i):
neu_signal = neu_signal[:, i]
return jnp.mean(neu_signal * neu_signal) - jnp.mean(neu_signal) ** 2


@jit
def voltage_fluctuation(potentials):
r"""Calculate neuronal synchronization via voltage variance.

The method comes from [1]_ [2]_ [3]_.

First, average over the membrane potential :math:`V`

.. math::

V(t) = \frac{1}{N} \sum_{i=1}^{N} V_i(t)

The variance of the time fluctuations of :math:`V(t)` is

.. math::

\sigma_V^2 = \left\langle \left[ V(t) \right]^2 \right\rangle_t -
\left[ \left\langle V(t) \right\rangle_t \right]^2

where :math:`\left\langle \ldots \right\rangle_t = (1 / T_m) \int_0^{T_m} dt \, \ldots`
denotes time-averaging over a large time, :math:`\tau_m`. After normalization
of :math:`\sigma_V` to the average over the population of the single cell
membrane potentials

.. math::

\sigma_{V_i}^2 = \left\langle\left[ V_i(t) \right]^2 \right\rangle_t -
\left[ \left\langle V_i(t) \right\rangle_t \right]^2

one defines a synchrony measure, :math:`\chi (N)`, for the activity of a system
of :math:`N` neurons by:

.. math::

\chi^2 \left( N \right) = \frac{\sigma_V^2}{ \frac{1}{N} \sum_{i=1}^N
\sigma_{V_i}^2}

Parameters
----------
potentials :
The membrane potential matrix of the neuron group.

Returns
-------
sync_index : float
The synchronization index.

References
----------
.. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled
inhibitory neurons with heterogeneity. Phys. Rev. reversal_potential 48:4810-4814.
.. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled
inhibitory neurons. Physica D 72:259-282.
.. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347.
"""

potentials = bm.as_device_array(potentials)
num_hist, num_neu = potentials.shape
var_mean = jnp.mean(_var(potentials, jnp.arange(num_neu)))
avg = jnp.mean(potentials, axis=1)
avg_var = jnp.mean(avg * avg) - jnp.mean(avg) ** 2
return lax.cond(var_mean != 0.,
lambda _: avg_var / var_mean,
lambda _: 1.,
())


def matrix_correlation(x, y):
"""Pearson correlation of the lower triagonal of two matrices.

The triangular matrix is offset by k = 1 in order to ignore the diagonal line

Parameters
----------
x: tensor
First matrix.
y: tensor
Second matrix

Returns
-------
coef: tensor
Correlation coefficient
"""
x = bm.as_numpy(x)
y = bm.as_numpy(y)
if x.ndim != 2:
raise ValueError(f'Only support 2d tensor, but we got a tensor '
f'with the shape of {x.shape}')
if y.ndim != 2:
raise ValueError(f'Only support 2d tensor, but we got a tensor '
f'with the shape of {y.shape}')
x = x[np.triu_indices_from(x, k=1)]
y = y[np.triu_indices_from(y, k=1)]
cc = np.corrcoef(x, y)[0, 1]
return cc


def functional_connectivity(activities):
"""Functional connectivity matrix of timeseries activities.

Parameters
----------
activities: tensor
The multidimensional tensor with the shape of ``(num_time, num_sample)``.

Returns
-------
connectivity_matrix: tensor
``num_sample x num_sample`` functional connectivity matrix.
"""
activities = bm.as_numpy(activities)
if activities.ndim != 2:
raise ValueError('Only support 2d tensor with shape of "(num_time, num_sample)". '
f'But we got a tensor with the shape of {activities.shape}')
fc = np.corrcoef(activities.T)
return np.nan_to_num(fc)


@jit
def functional_connectivity_dynamics(activities, window_size=30, step_size=5):
"""Computes functional connectivity dynamics (FCD) matrix.

Parameters
----------
activities: tensor
The time series with shape of ``(num_time, num_sample)``.
window_size: int
Size of each rolling window in time steps, defaults to 30.
step_size: int
Step size between each rolling window, defaults to 5.

Returns
-------
fcd_matrix: tensor
FCD matrix.
"""
pass


def _weighted_mean(x, w):
"""Weighted Mean"""
return jnp.sum(x * w) / jnp.sum(w)


def _weighted_cov(x, y, w):
"""Weighted Covariance"""
return jnp.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / jnp.sum(w)


@jit
def weighted_correlation(x, y, w):
"""Weighted Pearson correlation of two data series.

Parameters
----------
x: tensor
The data series 1.
y: tensor
The data series 2.
w: tensor
Weight vector, must have same length as x and y.

Returns
-------
corr: tensor
Weighted correlation coefficient.
"""
x = bm.as_device_array(x)
y = bm.as_device_array(y)
w = bm.as_device_array(w)
if x.ndim != 1:
raise ValueError(f'Only support 1d tensor, but we got a tensor '
f'with the shape of {x.shape}')
if y.ndim != 1:
raise ValueError(f'Only support 1d tensor, but we got a tensor '
f'with the shape of {y.shape}')
if w.ndim != 1:
raise ValueError(f'Only support 1d tensor, but we got a tensor '
f'with the shape of {w.shape}')
return _weighted_cov(x, y, w) / jnp.sqrt(_weighted_cov(x, x, w) * _weighted_cov(y, y, w))

+ 76
- 0
brainpy/measure/firings.py View File

@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-

import numpy as np
from jax import jit

from brainpy import math as bm

__all__ = [
'raster_plot',
'firing_rate',
]


def raster_plot(sp_matrix, times):
"""Get spike raster plot which displays the spiking activity
of a group of neurons over time.

Parameters
----------
sp_matrix : bnp.ndarray
The matrix which record spiking activities.
times : bnp.ndarray
The time steps.

Returns
-------
raster_plot : tuple
Include (neuron index, spike time).
"""
sp_matrix = np.asarray(sp_matrix)
times = np.asarray(times)
elements = np.where(sp_matrix > 0.)
index = elements[1]
time = times[elements[0]]
return index, time


@jit
def _firing_rate(sp_matrix, window):
sp_matrix = bm.asarray(sp_matrix)
rate = bm.sum(sp_matrix, axis=1) / sp_matrix.shape[1]
return bm.convolve(rate, window, mode='same')


def firing_rate(sp_matrix, width, dt=None, numpy=True):
r"""Calculate the mean firing rate over in a neuron group.

This method is adopted from Brian2.

The firing rate in trial :math:`k` is the spike count :math:`n_{k}^{sp}`
in an interval of duration :math:`T` divided by :math:`T`:

.. math::

v_k = {n_k^{sp} \over T}

Parameters
----------
sp_matrix : math.JaxArray, np.ndarray
The spike matrix which record spiking activities.
width : int, float
The width of the ``window`` in millisecond.
dt : float, optional
The sample rate.

Returns
-------
rate : numpy.ndarray
The population rate in Hz, smoothed with the given window.
"""
dt = bm.get_dt() if (dt is None) else dt
width1 = int(width / 2 / dt) * 2 + 1
window = bm.ones(width1) * 1000 / width
fr = _firing_rate(sp_matrix, window)
return fr.numpy() if numpy else fr


+ 59
- 0
brainpy/measure/tests/test_correlation.py View File

@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-


import unittest
import brainpy as bp


class TestCrossCorrelation(unittest.TestCase):
def test_cc(self):
spikes = bp.math.ones((1000, 10))
cc1 = bp.measure.cross_correlation(spikes, 1.)
self.assertTrue(cc1 == 1.)

spikes = bp.math.zeros((1000, 10))
cc2 = bp.measure.cross_correlation(spikes, 1.)
self.assertTrue(cc2 == 0.)

def test_cc2(self):
spikes = bp.math.random.randint(0, 2, (1000, 10))
print(bp.measure.cross_correlation(spikes, 1.))
print(bp.measure.cross_correlation(spikes, 0.5))

def test_cc3(self):
spikes = bp.math.random.random((1000, 100)) < 0.8
print(bp.measure.cross_correlation(spikes, 1.))
print(bp.measure.cross_correlation(spikes, 0.5))

def test_cc4(self):
spikes = bp.math.random.random((1000, 100)) < 0.2
print(bp.measure.cross_correlation(spikes, 1.))
print(bp.measure.cross_correlation(spikes, 0.5))

def test_cc5(self):
spikes = bp.math.random.random((1000, 100)) < 0.05
print(bp.measure.cross_correlation(spikes, 1.))
print(bp.measure.cross_correlation(spikes, 0.5))


class TestVoltageFluctuation(unittest.TestCase):
def test_vf1(self):
voltages = bp.math.random.normal(0, 10, size=(1000, 100))
print(bp.measure.voltage_fluctuation(voltages))

voltages = bp.math.ones((1000, 100))
print(bp.measure.voltage_fluctuation(voltages))


class TestFunctionalConnectivity(unittest.TestCase):
def test_cf1(self):
act = bp.math.random.random((10000, 3))
print(bp.measure.functional_connectivity(act))


class TestMatrixCorrelation(unittest.TestCase):
def test_mc(self):
A = bp.math.random.random((100, 100))
B = bp.math.random.random((100, 100))
print(bp.measure.matrix_correlation(A, B))


+ 22
- 0
brainpy/measure/tests/test_firings.py View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-


import unittest
import brainpy as bp


class TestFiringRate(unittest.TestCase):
def test_fr1(self):
spikes = bp.math.ones((1000, 10))
print(bp.measure.firing_rate(spikes, 1.))

def test_fr2(self):
spikes = bp.math.random.random((1000, 10)) < 0.2
print(bp.measure.firing_rate(spikes, 1.))
print(bp.measure.firing_rate(spikes, 10.))

def test_fr3(self):
spikes = bp.math.random.random((1000, 10)) < 0.02
print(bp.measure.firing_rate(spikes, 1.))
print(bp.measure.firing_rate(spikes, 5.))


+ 314
- 219
brainpy/nn/base.py View File

@@ -12,7 +12,7 @@ This module provide basic Node class for whole ``brainpy.nn`` system.
This means ``brainpy.nn.Network`` is only used to pack element nodes. It will be
never be an element node.
- ``brainpy.nn.FrozenNetwork``: The whole network which can be represented as a basic
elementary node when composing a larger network. TODO
elementary node when composing a larger network (TODO).
"""

from copy import copy, deepcopy
@@ -48,6 +48,16 @@ __all__ = [

NODE_STATES = ['inputs', 'feedbacks', 'state', 'output']

SUPPORTED_LAYOUTS = ['shell_layout',
'multipartite_layout',
'spring_layout',
'spiral_layout',
'spectral_layout',
'random_layout',
'planar_layout',
'kamada_kawai_layout',
'circular_layout']


def not_implemented(fun: Callable) -> Callable:
"""Marks the given module method is not implemented.
@@ -92,8 +102,10 @@ class Node(Base):
self._is_ff_initialized = False
self._is_fb_initialized = False
self._is_state_initialized = False
self._is_fb_state_initialized = False
self._trainable = trainable
self._state = None # the state of the current node
self._fb_output = None # the feedback output of the current node
# data pass function
if self.data_pass_type not in DATA_PASS_FUNC:
raise ValueError(f'Unsupported data pass type {self.data_pass_type}. '
@@ -111,12 +123,9 @@ class Node(Base):
name = type(self).__name__
prefix = ' ' * (len(name) + 1)
line1 = (f"{name}(name={self.name}, "
f"trainable={self.trainable}, "
f"forwards={self.feedforward_shapes}, "
f"feedbacks={self.feedback_shapes}, \n")
line2 = (f"{prefix}output={self.output_shape}, "
f"support_feedback={self.support_feedback}, "
f"data_pass_type={self.data_pass_type})")
line2 = f"{prefix}output={self.output_shape}"
return line1 + line2

def __call__(self, *args, **kwargs) -> Tensor:
@@ -194,7 +203,7 @@ class Node(Base):
@property
def state(self) -> Optional[Tensor]:
"""Node current internal state."""
if self.is_ff_initialized:
if self._is_ff_initialized:
return self._state
return None

@@ -209,9 +218,9 @@ class Node(Base):

This method allows the maximum flexibility to change the
node state. It can set a new data (same shape, same dtype)
to the state. It can also set the data with another batch size.
We highly recommend the user to use this function.
to the state. It can also set a new data with the different
shape. We highly recommend the user to use this function.
instead of using ``self.state.value``.
"""
if self.state is None:
if self.output_shape is not None:
@@ -225,31 +234,52 @@ class Node(Base):
self.state._value = bm.as_device_array(state)

@property
def trainable(self) -> bool:
"""Returns if the Node can be trained."""
return self._trainable
def fb_output(self) -> Optional[Tensor]:
return self._fb_output

@property
def is_ff_initialized(self) -> bool:
return self._is_ff_initialized
@fb_output.setter
def fb_output(self, value: Tensor):
raise NotImplementedError('Please use "set_fb_output()" to reset the node feedback state, '
'or use "self.fb_output.value" to change the state content.')

@is_ff_initialized.setter
def is_ff_initialized(self, value: bool):
assert isinstance(value, bool)
self._is_ff_initialized = value
def set_fb_output(self, state: Tensor):
"""
Safely set the feedback state of the node.

@property
def is_fb_initialized(self) -> bool:
return self._is_fb_initialized
This method allows the maximum flexibility to change the
node state. It can set a new data (same shape, same dtype)
to the state. It can also set a new data with the different
shape. We highly recommend the user to use this function.
instead of using ``self.fb_output.value``.
"""
if self.fb_output is None:
if self.output_shape is not None:
check_batch_shape(self.output_shape, state.shape)
self._fb_output = bm.Variable(state) if not isinstance(state, bm.Variable) else state
else:
check_batch_shape(self.fb_output.shape, state.shape)
if self.fb_output.dtype != state.dtype:
raise MathError('Cannot set the feedback state, because the dtype is '
f'not consistent: {self.fb_output.dtype} != {state.dtype}')
self.fb_output._value = bm.as_device_array(state)

@is_fb_initialized.setter
def is_fb_initialized(self, value: bool):
assert isinstance(value, bool)
self._is_fb_initialized = value
@property
def trainable(self) -> bool:
"""Returns if the Node can be trained."""
return self._trainable

@property
def is_state_initialized(self):
return self._is_state_initialized
def is_initialized(self) -> bool:
if self._is_ff_initialized and self._is_state_initialized:
if self.feedback_shapes is not None:
if self._is_fb_initialized and self._is_fb_state_initialized:
return True
else:
return False
else:
return True
else:
return False

@trainable.setter
def trainable(self, value: bool):
@@ -268,7 +298,7 @@ class Node(Base):
self.set_feedforward_shapes(size)

def set_feedforward_shapes(self, feedforward_shapes: Dict):
if not self.is_ff_initialized:
if not self._is_ff_initialized:
check_dict_data(feedforward_shapes,
key_type=(Node, str),
val_type=(list, tuple),
@@ -278,11 +308,11 @@ class Node(Base):
if self.feedforward_shapes is not None:
for key, size in self._feedforward_shapes.items():
if key not in feedforward_shapes:
raise ValueError(f"Impossible to reset the input data of {self.name}. "
raise ValueError(f"Impossible to reset the input shape of {self.name}. "
f"Because this Node has the input dimension {size} from {key}. "
f"While we do not find it in the given feedforward_shapes")
if not check_batch_shape(size, feedforward_shapes[key], mode='bool'):
raise ValueError(f"Impossible to reset the input data of {self.name}. "
raise ValueError(f"Impossible to reset the input shape of {self.name}. "
f"Because this Node has the input dimension {size} from {key}. "
f"While the give shape is {feedforward_shapes[key]}")

@@ -296,7 +326,7 @@ class Node(Base):
self.set_feedback_shapes(size)

def set_feedback_shapes(self, fb_shapes: Dict):
if not self.is_fb_initialized:
if not self._is_fb_initialized:
check_dict_data(fb_shapes, key_type=(Node, str), val_type=(tuple, list), name='fb_shapes')
self._feedback_shapes = fb_shapes
else:
@@ -321,14 +351,21 @@ class Node(Base):
self.set_output_shape(size)

@property
def support_feedback(self):
if hasattr(self.init_fb, 'not_implemented'):
if self.init_fb.not_implemented:
def is_feedback_input_supported(self):
if hasattr(self.init_fb_conn, 'not_implemented'):
if self.init_fb_conn.not_implemented:
return False
return True

@property
def is_feedback_supported(self):
if self.fb_output is None:
return False
else:
return True

def set_output_shape(self, shape: Sequence[int]):
if not self.is_ff_initialized:
if not self._is_ff_initialized:
if not isinstance(shape, (tuple, list)):
raise ValueError(f'Must be a sequence of int, but got {shape}')
self._output_shape = tuple(shape)
@@ -368,84 +405,88 @@ class Node(Base):
new_obj.name = self.unique_name(name or (self.name + '_copy'))
return new_obj

def _ff_init(self):
if not self.is_ff_initialized:
def _init_ff_conn(self):
if not self._is_ff_initialized:
try:
self.init_ff()
self.init_ff_conn()
except Exception as e:
raise ModelBuildError(f'{self.name} initialization failed.') from e
self._is_ff_initialized = True
if self.output_shape is None:
raise ValueError(f'Please set the output shape when implementing '
f'"init_ff()" of the node {self.name}')

def _fb_init(self):
if not self.is_fb_initialized:
def _init_fb_conn(self):
if not self._is_fb_initialized:
try:
self.init_fb()
self.init_fb_conn()
except Exception as e:
raise ModelBuildError(f"{self.name} initialization failed.") from e
self._is_fb_initialized = True

@not_implemented
def init_fb(self):
def init_fb_conn(self):
"""Initialize the feedback connections.
This function will be called only once."""
raise ValueError(f'This node \n\n{self} \n\ndoes not support feedback connection.')

def init_ff(self):
def init_ff_conn(self):
"""Initialize the feedforward connections.
This function will be called only once."""
raise NotImplementedError('Please implement the feedforward initialization.')

def init_state(self, num_batch=1):
def _init_state(self, num_batch=1):
state = self.init_state(num_batch)
if state is not None:
self.set_state(state)

def _init_fb_output(self, num_batch=1):
output = self.init_fb_output(num_batch)
if output is not None:
self.set_fb_output(output)

def init_state(self, num_batch=1) -> Optional[Tensor]:
"""Set the initial node state.

This function can be called multiple times."""
pass

def initialize(self,
ff: Optional[Union[Tensor, Dict[Any, Tensor]]] = None,
fb: Optional[Union[Tensor, Dict[Any, Tensor]]] = None,
num_batch: int = None):
def init_fb_output(self, num_batch=1) -> Optional[Tensor]:
"""Set the initial node feedback state.

This function can be called multiple times. However,
it is only triggered when the node has feedback connections.
"""
Initialize the whole network. This function must be called before applying JIT.
return bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_)

This function is useful, because it is independent from the __call__ function.
We can use this function before we applying JIT to __call__ function.
def initialize(self, num_batch: int):
"""
Initialize the node. This function must be called before applying JIT.

# feedforward initialization
if not self.is_ff_initialized:
# feedforward data
if ff is None:
if self._feedforward_shapes is None:
raise ValueError('Cannot initialize this node, because we detect '
'both "feedforward_shapes"and "ff" inputs are None. ')
in_sizes = self._feedforward_shapes
if num_batch is None:
raise ValueError('"num_batch" cannot be None when "ff" is not provided.')
check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False)
else:
if isinstance(ff, (bm.ndarray, jnp.ndarray)):
ff = {self.name: ff}
assert isinstance(ff, dict), f'"ff" must be a dict or a tensor, got {type(ff)}: {ff}'
assert self.name in ff, f'Cannot find input for this node \n\n{self} \n\nwhen given "ff" {ff}'
batch_sizes = [v.shape[0] for v in ff.values()]
if set(batch_sizes) != 1:
raise ValueError('Batch sizes must be consistent, but we got multiple '
f'batch sizes {set(batch_sizes)} for the given input: \n'
f'{ff}')
in_sizes = {k: (None,) + v.shape[1:] for k, v in ff.items()}
if (num_batch is not None) and (num_batch != batch_sizes[0]):
raise ValueError(f'The provided "num_batch" {num_batch} is consistent with the '
f'batch size of the provided data {batch_sizes[0]}')

# initialize feedforward
self.set_feedforward_shapes(in_sizes)
self._ff_init()
self.init_state(num_batch)
self._is_state_initialized = True
This function is useful, because it is independent of the __call__ function.
We can use this function before we apply JIT to __call__ function.
"""

# feedback initialization
if fb is not None:
if not self.is_fb_initialized: # initialize feedback
assert isinstance(fb, dict), f'"fb" must be a dict, got {type(fb)}'
fb_sizes = {k: (None,) + v.shape[1:] for k, v in fb.items()}
self.set_feedback_shapes(fb_sizes)
self._fb_init()
else:
self._is_fb_initialized = True
# feedforward initialization
if self.feedforward_shapes is None:
raise ValueError('Cannot initialize this node, because we detect '
'both "feedforward_shapes" is None. '
'Two ways can solve this problem:\n\n'
'1. Connecting an instance of "brainpy.nn.Input()" to this node. \n'
'2. Providing the "input_shape" when initialize the node.')
check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False)
self._init_ff_conn()

# initialize state
self._init_state(num_batch)
self._is_state_initialized = True

if self.feedback_shapes is not None:
# feedback initialization
self._init_fb_conn()
# initialize feedback state
self._init_fb_output(num_batch)
self._is_fb_state_initialized = True

def _check_inputs(self, ff, fb=None):
# check feedforward inputs
@@ -477,9 +518,8 @@ class Node(Base):
forced_feedbacks: Dict[str, Tensor] = None,
monitors=None,
**kwargs) -> Union[Tensor, Tuple[Tensor, Dict]]:
# # initialization
# self.initialize(ff, fb)
if not (self.is_ff_initialized and self.is_fb_initialized and self.is_state_initialized):
# checking
if not self.is_initialized:
raise ValueError('Please initialize the Node first by calling "initialize()" function.')

# initialize the forced data
@@ -511,6 +551,7 @@ class Node(Base):
assert self.state is not None, (f'{self} \n\nhas no state, while '
f'the user try to monitor its state.')
state_monitors[key] = None

# calling
ff, fb = self._check_inputs(ff, fb=fb)
if 'inputs' in state_monitors:
@@ -528,7 +569,7 @@ class Node(Base):
else:
return output

def forward(self, ff, fb=None, **kwargs):
def forward(self, ff, fb=None, **shared_kwargs):
"""The feedforward computation function of a node.

Parameters
@@ -537,7 +578,7 @@ class Node(Base):
The feedforward inputs.
fb: optional, tensor, dict, sequence
The feedback inputs.
**kwargs
**shared_kwargs
Other parameters.

Returns
@@ -547,12 +588,12 @@ class Node(Base):
"""
raise NotImplementedError

def feedback(self, **kwargs):
def feedback(self, ff_output, **shared_kwargs):
"""The feedback computation function of a node.

Parameters
----------
**kwargs
**shared_kwargs
Other global parameters.

Returns
@@ -560,12 +601,17 @@ class Node(Base):
Tensor
A feedback output tensor value.
"""
return self.state
return ff_output


class RecurrentNode(Node):
"""
Basic class for recurrent node.

The supports for the recurrent node are:

- Self-connection when using ``plot_node_graph()`` function
- Set trainable state with ``state_trainable=True``.
"""

def __init__(self,
@@ -617,19 +663,6 @@ class RecurrentNode(Node):
else:
self.state._value = bm.as_device_array(state)

def __repr__(self):
name = type(self).__name__
prefix = ' ' * (len(name) + 1)

line1 = (f"{name}(name={self.name}, recurrent=True, "
f"trainable={self.trainable}, \n")
line2 = (f"{prefix}forwards={self.feedforward_shapes}, "
f"feedbacks={self.feedback_shapes}, \n")
line3 = (f"{prefix}output={self.output_shape}, "
f"support_feedback={self.support_feedback}, "
f"data_pass_type={self.data_pass_type})")
return line1 + line2 + line3


class Network(Node):
"""Basic Network class for neural network building in BrainPy."""
@@ -806,8 +839,8 @@ class Network(Node):

def replace_graph(self,
nodes: Sequence[Node],
ff_edges: Sequence[Tuple[Node, Node]],
fb_edges: Sequence[Tuple[Node, Node]] = None) -> "Network":
ff_edges: Sequence[Tuple[Node, ...]],
fb_edges: Sequence[Tuple[Node, ...]] = None) -> "Network":
if fb_edges is None: fb_edges = tuple()

# assign nodes and edges
@@ -817,16 +850,45 @@ class Network(Node):
self._network_init()
return self

def init_ff(self):
def set_output_shape(self, shape: Dict[str, Sequence[int]]):
# check shape
if not isinstance(shape, dict):
raise ValueError(f'Must be a dict of <node name, shape>, but got {type(shape)}: {shape}')
for key, val in shape.items():
if not isinstance(val, (tuple, list)):
raise ValueError(f'Must be a sequence of int, but got {val} for key "{key}"')
# for s in val:
# if not (isinstance(s, int) or (s is None)):
# raise ValueError(f'Must be a sequence of int, but got {val}')

if not self._is_ff_initialized:
if len(self.exit_nodes) == 1:
self._output_shape = tuple(shape.values())[0]
else:
self._output_shape = shape
else:
for val in shape.values():
check_batch_shape(val, self.output_shape)

def init_ff_conn(self):
"""Initialize the feedforward connections of the network.
This function will be called only once."""
# input shapes of entry nodes
for node in self.entry_nodes:
# set ff shapes
if node.feedforward_shapes is None:
if self.feedforward_shapes is None:
raise ValueError('Cannot find the input size. '
'Cannot initialize the network.')
else:
node.set_feedforward_shapes({node.name: self._feedforward_shapes[node.name]})
node._ff_init()
# set fb shapes
if node in self.fb_senders:
fb_shapes = {node: node.output_shape for node in self.fb_senders.get(node, [])}
if None not in fb_shapes.values():
node.set_feedback_shapes(fb_shapes)
# init ff conn
node._init_ff_conn()

# initialize the data
children_queue = []
@@ -840,49 +902,79 @@ class Network(Node):
children_queue.append(child)
while len(children_queue):
node = children_queue.pop(0)
# initialize input and output sizes
# set ff shapes
parent_sizes = {p: p.output_shape for p in self.ff_senders.get(node, [])}
node.set_feedforward_shapes(parent_sizes)
node._ff_init()
if node in self.fb_senders:
# set fb shapes
fb_shapes = {node: node.output_shape for node in self.fb_senders.get(node, [])}
if None not in fb_shapes.values():
node.set_feedback_shapes(fb_shapes)
# init ff conn
node._init_ff_conn()
# append children
for child in self.ff_receivers.get(node, []):
ff_senders[child].remove(node)
if len(ff_senders.get(child, [])) == 0:
children_queue.append(child)

def init_fb(self):
# set output shape
out_sizes = {node: node.output_shape for node in self.exit_nodes}
self.set_output_shape(out_sizes)

def init_fb_conn(self):
"""Initialize the feedback connections of the network.
This function will be called only once."""
for receiver, senders in self.fb_senders.items():
fb_sizes = {node: node.output_shape for node in senders}
if None in fb_sizes.values():
none_size_nodes = [repr(n) for n, v in fb_sizes.items() if v is None]
none_size_nodes = "\n".join(none_size_nodes)
raise ValueError(f'Output shapes of nodes \n\n'
f'{none_size_nodes}\n\n'
f'have not been initialized, '
f'leading us cannot initialize the '
f'feedback connection of node \n\n'
f'{receiver}')
receiver.set_feedback_shapes(fb_sizes)
receiver._fb_init()
receiver._init_fb_conn()

def init_state(self, num_batch=1):
"""Initialize the states of all children nodes."""
def _init_state(self, num_batch=1):
"""Initialize the states of all children nodes.
This function can be called multiple times."""
for node in self.lnodes:
node.init_state(num_batch)
node._init_state(num_batch)

def _init_fb_output(self, num_batch=1):
"""Initialize the node feedback state.

def initialize(self,
ff: Optional[Union[Tensor, Dict[Any, Tensor]]] = None,
fb: Optional[Union[Tensor, Dict[Any, Tensor]]] = None,
num_batch: int = None):
This function can be called multiple times. However,
it is only triggered when the node has feedback connections.
"""
for node in self.feedback_nodes:
node._init_fb_output(num_batch)

def initialize(self, num_batch: int):
"""
Initialize the whole network. This function must be called before applying JIT.

This function is useful, because it is independent from the __call__ function.
We can use this function before we applying JIT to __call__ function.
This function is useful, because it is independent of the __call__ function.
We can use this function before we apply JIT to __call__ function.
"""

# feedforward initialization
if not self.is_ff_initialized:
# set feedforward shapes
if not self._is_ff_initialized:
# check input and output nodes
assert len(self.entry_nodes) > 0, (f"We found this network \n\n"
f"{self} "
f"\n\nhas no input nodes.")
assert len(self.exit_nodes) > 0, (f"We found this network \n\n"
f"{self} "
f"\n\nhas no output nodes.")

# check whether has a feedforward path for each feedback pair
if len(self.entry_nodes) <= 0:
raise ValueError(f"We found this network \n\n"
f"{self} "
f"\n\nhas no input nodes.")
if len(self.exit_nodes) <= 0:
raise ValueError(f"We found this network \n\n"
f"{self} "
f"\n\nhas no output nodes.")

# check whether it has a feedforward path for each feedback pair
ff_edges = [(a.name, b.name) for a, b in self.ff_edges]
for node, receiver in self.fb_edges:
if not detect_path(receiver.name, node.name, ff_edges):
@@ -895,49 +987,42 @@ class Network(Node):
f'feedforward connection between them. ')

# feedforward checking
if ff is None:
in_sizes = dict()
for node in self.entry_nodes:
if node._feedforward_shapes is None:
raise ValueError('Cannot initialize this node, because we detect '
'both "feedforward_shapes" and "ff" inputs are None. '
'Maybe you need a brainpy.nn.Input instance '
'to instruct the input size.')
in_sizes.update(node._feedforward_shapes)
if num_batch is None:
raise ValueError('"num_batch" cannot be None when "ff" is not provided.')
check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False)

else:
if isinstance(ff, (bm.ndarray, jnp.ndarray)):
ff = {self.entry_nodes[0].name: ff}
assert isinstance(ff, dict), f'ff must be a dict or a tensor, got {type(ff)}: {ff}'
for n in self.entry_nodes:
if n.name not in ff:
raise ValueError(f'Cannot find the input of the node {n}')
batch_sizes = [v.shape[0] for v in ff.values()]
if len(set(batch_sizes)) != 1:
raise ValueError('Batch sizes must be consistent, but we got multiple '
f'batch sizes {set(batch_sizes)} for the given input: \n'
f'{ff}')
in_sizes = {k: (None,) + v.shape[1:] for k, v in ff.items()}
if (num_batch is not None) and (num_batch != batch_sizes[0]):
raise ValueError(f'The provided "num_batch" {num_batch} is consistent with the '
f'batch size of the provided data {batch_sizes[0]}')

# initialize feedforward
in_sizes = dict()
for node in self.entry_nodes:
if node.feedforward_shapes is None:
raise ValueError('Cannot initialize this node, because we detect '
'"feedforward_shapes" is None. '
'Maybe you need a brainpy.nn.Input instance '
'to instruct the input size.')
in_sizes.update(node._feedforward_shapes)
self.set_feedforward_shapes(in_sizes)
self._ff_init()
self.init_state(num_batch)
self._is_state_initialized = True

# feedforward initialization
if self.feedforward_shapes is None:
raise ValueError('Cannot initialize this node, because we detect '
'both "feedforward_shapes" is None. ')
check_integer(num_batch, 'num_batch', min_bound=1, allow_none=False)
self._init_ff_conn()

# initialize state
self._init_state(num_batch)
self._is_state_initialized = True

# set feedback shapes
if not self._is_fb_initialized:
if len(self.fb_senders) > 0:
fb_sizes = dict()
for sender in self.fb_senders.keys():
fb_sizes[sender] = sender.output_shape
self.set_feedback_shapes(fb_sizes)

# feedback initialization
if len(self.fb_senders):
# initialize feedback
if not self.is_fb_initialized:
self._fb_init()
else:
self.is_fb_initialized = True
if self.feedback_shapes is not None:
self._init_fb_conn()
# initialize feedback state
self._init_fb_output(num_batch)
self._is_fb_state_initialized = True

def _check_inputs(self, ff, fb=None):
# feedforward inputs
@@ -986,8 +1071,7 @@ class Network(Node):
monitors: Optional[Sequence[str]] = None,
**kwargs):
# initialization
# self.initialize(ff, fb)
if not (self.is_ff_initialized and self.is_fb_initialized and self.is_state_initialized):
if not self.is_initialized:
raise ValueError('Please initialize the Network first by calling "initialize()" function.')

# initialize the forced data
@@ -1038,7 +1122,7 @@ class Network(Node):
forced_states: Dict[str, Tensor] = None,
forced_feedbacks: Dict[str, Tensor] = None,
monitors: Dict = None,
**kwargs):
**shared_kwargs):
"""The main computation function of a network.

Parameters
@@ -1053,7 +1137,7 @@ class Network(Node):
The fixed feedback for the nodes in the network.
monitors: optional, sequence
Can be used to monitor the state or the attribute of a node in the network.
**kwargs
**shared_kwargs
Other parameters which will be parsed into every node.

Returns
@@ -1077,10 +1161,11 @@ class Network(Node):
parent_outputs = {}
for i, node in enumerate(self._entry_nodes):
ff_ = {node.name: ff[i]}
fb_ = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.feedback())
fb_ = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.fb_output)
for p in self.fb_senders.get(node, [])}
self._call_a_node(node, ff_, fb_, monitors, forced_states,
parent_outputs, children_queue, ff_senders, **kwargs)
parent_outputs, children_queue, ff_senders,
**shared_kwargs)
runned_nodes.add(node.name)

# run the model
@@ -1088,23 +1173,23 @@ class Network(Node):
node = children_queue.pop(0)
# get feedforward and feedback inputs
ff = {p: parent_outputs[p] for p in self.ff_senders.get(node, [])}
fb = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.feedback())
fb = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.fb_output)
for p in self.fb_senders.get(node, [])}
# call the node
self._call_a_node(node, ff, fb, monitors, forced_states,
parent_outputs, children_queue, ff_senders,
**kwargs)
# #- remove unnecessary parent outputs -#
# needed_parents = []
# runned_nodes.add(node.name)
# for child in (all_nodes - runned_nodes):
# for parent in self.ff_senders[self.implicit_nodes[child]]:
# needed_parents.append(parent.name)
# for parent in list(parent_outputs.keys()):
# _name = parent.name
# if _name not in needed_parents and _name not in output_nodes:
# parent_outputs.pop(parent)
**shared_kwargs)
# - remove unnecessary parent outputs - #
needed_parents = []
runned_nodes.add(node.name)
for child in (all_nodes - runned_nodes):
for parent in self.ff_senders[self.implicit_nodes[child]]:
needed_parents.append(parent.name)
for parent in list(parent_outputs.keys()):
_name = parent.name
if _name not in needed_parents and _name not in output_nodes:
parent_outputs.pop(parent)

# returns
if len(self.exit_nodes) > 1:
@@ -1114,7 +1199,8 @@ class Network(Node):
return state, monitors

def _call_a_node(self, node, ff, fb, monitors, forced_states,
parent_outputs, children_queue, ff_senders, **kwargs):
parent_outputs, children_queue, ff_senders,
**shared_kwargs):
ff = node.data_pass_func(ff)
if f'{node.name}.inputs' in monitors:
monitors[f'{node.name}.inputs'] = ff
@@ -1123,12 +1209,17 @@ class Network(Node):
fb = node.data_pass_func(fb)
if f'{node.name}.feedbacks' in monitors:
monitors[f'{node.name}.feedbacks'] = fb
parent_outputs[node] = node.forward(ff, fb, **kwargs)
parent_outputs[node] = node.forward(ff, fb, **shared_kwargs)
else:
parent_outputs[node] = node.forward(ff, **kwargs)
if node.name in forced_states: # forced state
parent_outputs[node] = node.forward(ff, **shared_kwargs)
# get the feedback state
if node in self.fb_receivers:
node.set_fb_output(node.feedback(parent_outputs[node], **shared_kwargs))
# forced state
if node.name in forced_states:
node.state.value = forced_states[node.name]
parent_outputs[node] = forced_states[node.name]
# parent_outputs[node] = forced_states[node.name]
# monitor the values
if f'{node.name}.state' in monitors:
monitors[f'{node.name}.state'] = node.state.value
if f'{node.name}.output' in monitors:
@@ -1143,7 +1234,7 @@ class Network(Node):
fig_size: tuple = (10, 10),
node_size: int = 2000,
arrow_size: int = 20,
layout='spectral_layout'):
layout='shell_layout'):
"""Plot the node graph based on NetworkX package

Parameters
@@ -1155,7 +1246,17 @@ class Network(Node):
arrow_size:int, default to 20
The size of the arrow
layout: str
The graph layout. More please see networkx Graph Layout.
The graph layout. The supported layouts are:

- "shell_layout"
- "multipartite_layout"
- "spring_layout"
- "spiral_layout"
- "spectral_layout"
- "random_layout"
- "planar_layout"
- "kamada_kawai_layout"
- "circular_layout"
"""
try:
import networkx as nx
@@ -1204,15 +1305,8 @@ class Network(Node):
G.add_edges_from(fb_edges)
G.add_edges_from(rec_edges)

assert layout in ['shell_layout',
'multipartite_layout',
'spring_layout',
'spiral_layout',
'spectral_layout',
'random_layout',
'planar_layout',
'kamada_kawai_layout',
'circular_layout']
if layout not in SUPPORTED_LAYOUTS:
raise UnsupportedError(f'Only support layouts: {SUPPORTED_LAYOUTS}')
layout = getattr(nx, layout)(G)

plt.figure(figsize=fig_size)
@@ -1252,10 +1346,12 @@ class Network(Node):
proxie = []
labels = []
if len(nodes_trainable):
proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=trainable_color))
proxie.append(Line2D([], [], color='white', marker='o',
markerfacecolor=trainable_color))
labels.append('Trainable')
if len(nodes_untrainable):
proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=untrainable_color))
proxie.append(Line2D([], [], color='white', marker='o',
markerfacecolor=untrainable_color))
labels.append('Untrainable')
if len(ff_edges):
proxie.append(Line2D([], [], color=ff_color, linewidth=2))
@@ -1267,8 +1363,7 @@ class Network(Node):
proxie.append(Line2D([], [], color=rec_color, linewidth=2))
labels.append('Recurrent')

plt.legend(proxie, labels, scatterpoints=1, markerscale=2,
loc='best')
plt.legend(proxie, labels, scatterpoints=1, markerscale=2, loc='best')
plt.tight_layout()
plt.show()



+ 3
- 4
brainpy/nn/nodes/ANN/conv.py View File

@@ -4,9 +4,8 @@
import jax.lax

import brainpy.math as bm
from brainpy.initialize import XavierNormal, ZeroInit
from brainpy.initialize import XavierNormal, ZeroInit, init_param
from brainpy.nn.base import Node
from brainpy.nn.utils import init_param

__all__ = [
'Conv2D',
@@ -81,7 +80,7 @@ class Conv2D(Node):
self.padding = padding
self.groups = groups

def init_ff(self):
def init_ff_conn(self):
assert self.num_input % self.groups == 0, '"nin" should be divisible by groups'
size = _check_tuple(self.kernel_size) + (self.num_input // self.groups, self.num_output)
self.w = init_param(self.w_init, size)
@@ -90,7 +89,7 @@ class Conv2D(Node):
self.w = bm.TrainVar(self.w)
self.b = bm.TrainVar(self.b)

def forward(self, ff, **kwargs):
def forward(self, ff, **shared_kwargs):
x = ff[0]
nin = self.w.value.shape[2] * self.groups
assert x.shape[1] == nin, (f'Attempting to convolve an input with {x.shape[1]} input channels '


+ 3
- 3
brainpy/nn/nodes/ANN/dropout.py View File

@@ -44,11 +44,11 @@ class Dropout(Node):
self.prob = prob
self.rng = bm.random.RandomState(seed=seed)

def init_ff(self):
def init_ff_conn(self):
self.set_output_shape(self.feedforward_shapes)

def forward(self, ff, **kwargs):
if kwargs.get('train', True):
def forward(self, ff, **shared_kwargs):
if shared_kwargs.get('train', True):
keep_mask = self.rng.bernoulli(self.prob, ff.shape)
return bm.where(keep_mask, ff / self.prob, 0.)
else:


+ 94
- 64
brainpy/nn/nodes/ANN/rnn_cells.py View File

@@ -1,14 +1,20 @@
# -*- coding: utf-8 -*-


from typing import Union, Callable

import brainpy.math as bm
from brainpy.initialize import (XavierNormal, ZeroInit,
Uniform, Orthogonal)
from brainpy.initialize import (XavierNormal,
ZeroInit,
Uniform,
Orthogonal,
init_param,
Initializer)
from brainpy.nn.base import RecurrentNode
from brainpy.nn.utils import init_param
from brainpy.tools.checking import (check_integer,
check_initializer,
check_shape_consistency)
from brainpy.types import Tensor

__all__ = [
'VanillaRNN',
@@ -33,19 +39,21 @@ class VanillaRNN(RecurrentNode):
def __init__(
self,
num_unit: int,
state_initializer=Uniform(),
wi_initializer=XavierNormal(),
wh_initializer=XavierNormal(),
bias_initializer=ZeroInit(),
activation='relu',
trainable=True,
state_initializer: Union[Tensor, Callable, Initializer] = Uniform(),
wi_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(),
wh_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(),
bias_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(),
activation: str = 'relu',
trainable: bool = True,
**kwargs
):
super(VanillaRNN, self).__init__(trainable=trainable, **kwargs)

self.num_unit = num_unit
check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False)
self.set_output_shape((None, self.num_unit))

# initializers
self._state_initializer = state_initializer
self._wi_initializer = wi_initializer
self._wh_initializer = wh_initializer
@@ -55,23 +63,23 @@ class VanillaRNN(RecurrentNode):
check_initializer(state_initializer, 'state_initializer', allow_none=False)
check_initializer(bias_initializer, 'bias_initializer', allow_none=True)

# activation function
self.activation = bm.activations.get(activation)

def init_ff(self):
def init_ff_conn(self):
unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)
assert len(unique_size) == 1, 'Only support data with or without batch size.'
num_input = sum(free_sizes)
self.set_output_shape(unique_size + (self.num_unit,))
# weights
num_input = sum(free_sizes)
self.Wff = init_param(self._wi_initializer, (num_input, self.num_unit))
self.Wrec = init_param(self._wh_initializer, (self.num_unit, self.num_unit))
self.bff = init_param(self._bias_initializer, (self.num_unit,))
self.bias = init_param(self._bias_initializer, (self.num_unit,))
if self.trainable:
self.Wff = bm.TrainVar(self.Wff)
self.Wrec = bm.TrainVar(self.Wrec)
self.bff = None if (self.bff is None) else bm.TrainVar(self.bff)
self.bias = None if (self.bias is None) else bm.TrainVar(self.bias)

def init_fb(self):
def init_fb_conn(self):
unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True)
assert len(unique_size) == 1, 'Only support data with or without batch size.'
num_feedback = sum(free_sizes)
@@ -80,16 +88,15 @@ class VanillaRNN(RecurrentNode):
if self.trainable:
self.Wfb = bm.TrainVar(self.Wfb)

def init_state(self, num_batch):
state = init_param(self._state_initializer, (num_batch, self.num_unit))
self.set_state(state)
def init_state(self, num_batch=1):
return init_param(self._state_initializer, (num_batch, self.num_unit))

def forward(self, ff, fb=None, **kwargs):
def forward(self, ff, fb=None, **shared_kwargs):
ff = bm.concatenate(ff, axis=-1)
h = ff @ self.Wff
h += self.state.value @ self.Wrec
if self.bff is not None:
h += self.bff
if self.bias is not None:
h += self.bias
if fb is not None:
fb = bm.concatenate(fb, axis=-1)
h += fb @ self.Wfb
@@ -98,8 +105,7 @@ class VanillaRNN(RecurrentNode):


class GRU(RecurrentNode):
r"""
Gated Recurrent Unit.
r"""Gated Recurrent Unit.

The implementation is based on (Chung, et al., 2014) [1]_ with biases.

@@ -130,17 +136,18 @@ class GRU(RecurrentNode):
def __init__(
self,
num_unit: int,
wi_initializer=Orthogonal(),
wh_initializer=Orthogonal(),
bias_initializer=ZeroInit(),
state_initializer=ZeroInit(),
trainable=True,
wi_initializer: Union[Tensor, Callable, Initializer] = Orthogonal(),
wh_initializer: Union[Tensor, Callable, Initializer] = Orthogonal(),
bias_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(),
state_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(),
trainable: bool = True,
**kwargs
):
super(GRU, self).__init__(trainable=trainable, **kwargs)

self.num_unit = num_unit
check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False)
self.set_output_shape((None, self.num_unit))

self._wi_initializer = wi_initializer
self._wh_initializer = wh_initializer
@@ -151,30 +158,39 @@ class GRU(RecurrentNode):
check_initializer(state_initializer, 'state_initializer', allow_none=False)
check_initializer(bias_initializer, 'bias_initializer', allow_none=True)

def init_ff(self):
def init_ff_conn(self):
# data shape
unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)
assert len(unique_size) == 1, 'Only support data with or without batch size.'
num_input = sum(free_sizes)
self.set_output_shape(unique_size + (self.num_unit,))

# weights
self.i_weight = init_param(self._wi_initializer, (num_input, self.num_unit * 3))
self.h_weight = init_param(self._wh_initializer, (self.num_unit, self.num_unit * 3))
num_input = sum(free_sizes)
self.Wi_ff = init_param(self._wi_initializer, (num_input, self.num_unit * 3))
self.Wh = init_param(self._wh_initializer, (self.num_unit, self.num_unit * 3))
self.bias = init_param(self._bias_initializer, (self.num_unit * 3,))
if self.trainable:
self.i_weight = bm.TrainVar(self.i_weight)
self.h_weight = bm.TrainVar(self.h_weight)
self.Wi_ff = bm.TrainVar(self.Wi_ff)
self.Wh = bm.TrainVar(self.Wh)
self.bias = bm.TrainVar(self.bias) if (self.bias is not None) else None

def init_state(self, num_batch):
state = init_param(self._state_initializer, (num_batch, self.num_unit))
self.set_state(state)
def init_fb_conn(self):
unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True)
assert len(unique_size) == 1, 'Only support data with or without batch size.'
num_feedback = sum(free_sizes)
# weights
self.Wi_fb = init_param(self._wi_initializer, (num_feedback, self.num_unit * 3))
if self.trainable:
self.Wi_fb = bm.TrainVar(self.Wi_fb)

def forward(self, ff, fb=None, **kwargs):
ff = bm.concatenate(ff, axis=-1)
gates_x = bm.matmul(ff, self.i_weight)
def init_state(self, num_batch=1):
return init_param(self._state_initializer, (num_batch, self.num_unit))

def forward(self, ff, fb=None, **shared_kwargs):
gates_x = bm.matmul(bm.concatenate(ff, axis=-1), self.Wi_ff)
if fb is not None:
gates_x += bm.matmul(bm.concatenate(fb, axis=-1), self.Wi_fb)
zr_x, a_x = bm.split(gates_x, indices_or_sections=[2 * self.num_unit], axis=-1)
w_h_z, w_h_a = bm.split(self.h_weight, indices_or_sections=[2 * self.num_unit], axis=-1)
w_h_z, w_h_a = bm.split(self.Wh, indices_or_sections=[2 * self.num_unit], axis=-1)
zr_h = bm.matmul(self.state, w_h_z)
zr = zr_x + zr_h
has_bias = (self.bias is not None)
@@ -235,48 +251,62 @@ class LSTM(RecurrentNode):
def __init__(
self,
num_unit: int,
weight_initializer=Orthogonal(),
bias_initializer=ZeroInit(),
state_initializer=ZeroInit(),
trainable=True,
wi_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(),
wh_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(),
bias_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(),
state_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(),
trainable: bool = True,
**kwargs
):
super(LSTM, self).__init__(trainable=trainable, **kwargs)

self.num_unit = num_unit
check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False)
self.set_output_shape((None, self.num_unit,))

self._state_initializer = state_initializer
self._weight_initializer = weight_initializer
self._wi_initializer = wi_initializer
self._wh_initializer = wh_initializer
self._bias_initializer = bias_initializer
check_initializer(weight_initializer, 'weight_initializer', allow_none=False)
check_initializer(wi_initializer, 'wi_initializer', allow_none=False)
check_initializer(wh_initializer, 'wh_initializer', allow_none=False)
check_initializer(bias_initializer, 'bias_initializer', allow_none=True)
check_initializer(state_initializer, 'state_initializer', allow_none=False)

def init_ff(self):
def init_ff_conn(self):
# data shape
unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)
assert len(unique_size) == 1, 'Only support data with or without batch size.'
num_input = sum(free_sizes)
self.set_output_shape(unique_size + (self.num_unit,))
# weights
self.weight = init_param(self._weight_initializer, (num_input + self.num_unit, self.num_unit * 4))
num_input = sum(free_sizes)
self.Wi_ff = init_param(self._wi_initializer, (num_input, self.num_unit * 4))
self.Wh = init_param(self._wh_initializer, (self.num_unit, self.num_unit * 4))
self.bias = init_param(self._bias_initializer, (self.num_unit * 4,))
if self.trainable:
self.weight = bm.TrainVar(self.weight)
self.Wi_ff = bm.TrainVar(self.Wi_ff)
self.Wh = bm.TrainVar(self.Wh)
self.bias = None if (self.bias is None) else bm.TrainVar(self.bias)

def init_state(self, num_batch):
hc = init_param(self._state_initializer, (num_batch * 2, self.num_unit))
self.set_state(hc)
def init_fb_conn(self):
unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True)
assert len(unique_size) == 1, 'Only support data with or without batch size.'
num_feedback = sum(free_sizes)
# weights
self.Wi_fb = init_param(self._wi_initializer, (num_feedback, self.num_unit * 4))
if self.trainable:
self.Wi_fb = bm.TrainVar(self.Wi_fb)

def init_state(self, num_batch=1):
return init_param(self._state_initializer, (num_batch * 2, self.num_unit))

def forward(self, ff, fb=None, **kwargs):
def forward(self, ff, fb=None, **shared_kwargs):
h, c = bm.split(self.state, 2)
xh = bm.concatenate(tuple(ff) + (h,), axis=-1)
if self.bias is None:
gated = xh @ self.weight
else:
gated = xh @ self.weight + self.bias
gated = bm.concatenate(ff, axis=-1) @ self.Wi_ff
if fb is not None:
gated += bm.concatenate(fb, axis=-1) @ self.Wi_fb
if self.bias is not None:
gated += self.bias
gated += h @ self.Wh
i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1)
c = bm.sigmoid(f + 1.) * c + bm.sigmoid(i) * bm.tanh(g)
h = bm.sigmoid(o) * bm.tanh(c)
@@ -291,7 +321,7 @@ class LSTM(RecurrentNode):
@h.setter
def h(self, value):
if self.state is None:
raise ValueError('Cannot set "h" state. Because it is not initialized.')
raise ValueError('Cannot set "h" state. Because the state is not initialized.')
self.state[:self.state.shape[0] // 2, :] = value

@property
@@ -302,7 +332,7 @@ class LSTM(RecurrentNode):
@c.setter
def c(self, value):
if self.state is None:
raise ValueError('Cannot set "c" state. Because it is not initialized.')
raise ValueError('Cannot set "c" state. Because the state is not initialized.')
self.state[self.state.shape[0] // 2:, :] = value




+ 6
- 6
brainpy/nn/nodes/RC/linear_readout.py View File

@@ -39,12 +39,12 @@ class LinearReadout(Dense):
super(LinearReadout, self).__init__(num_unit=num_unit, weight_initializer=weight_initializer, bias_initializer=bias_initializer, **kwargs)

def init_state(self, num_batch=1):
state = bm.Variable(bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_))
self.set_state(state)
return bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_)

def forward(self, ff, fb=None, **kwargs):
self.state.value = super(LinearReadout, self).forward(ff, fb=fb, **kwargs)
return self.state
def forward(self, ff, fb=None, **shared_kwargs):
h = super(LinearReadout, self).forward(ff, fb=fb, **shared_kwargs)
self.state.value = h
return h

def __force_init__(self, train_pars: Optional[Dict] = None):
if train_pars is None: train_pars = dict()
@@ -76,4 +76,4 @@ class LinearReadout(Dense):
# update the weights
e = bm.atleast_2d(self.state - target) # (1, num_output)
dw = bm.dot(-c * k, e) # (num_hidden, num_output)
self.weights += dw
self.Wff += dw

+ 44
- 32
brainpy/nn/nodes/RC/nvar.py View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-

from itertools import combinations_with_replacement
from typing import Union
from typing import Union, Sequence

import numpy as np

@@ -9,7 +9,8 @@ import brainpy.math as bm
from brainpy.nn.base import RecurrentNode
from brainpy.tools.checking import (check_shape_consistency,
check_float,
check_integer)
check_integer,
check_sequence)

__all__ = [
'NVAR'
@@ -46,7 +47,7 @@ class NVAR(RecurrentNode):
----------
delay: int
The number of delay step.
order: int
order: int, sequence of int
The nonlinear order.
stride: int
The stride to sample linear part vector in the delays.
@@ -63,59 +64,67 @@ class NVAR(RecurrentNode):

def __init__(self,
delay: int,
order: int,
order: Union[int, Sequence[int]],
stride: int = 1,
constant: Union[float, int] = None,
**kwargs):
super(NVAR, self).__init__(**kwargs)

self.delay = delay
if not isinstance(order, (tuple, list)):
order = [order]
self.order = order
self.stride = stride
self.constant = constant
check_sequence(order, 'order', elem_type=int, allow_none=False)
self.delay = delay
check_integer(delay, 'delay', allow_none=False)
check_integer(order, 'order', allow_none=False)
self.stride = stride
check_integer(stride, 'stride', allow_none=False)
self.constant = constant
check_float(constant, 'constant', allow_none=True, allow_int=True)

def init_ff(self):
self.comb_ids = []
# delay variables
self.num_delay = self.delay * self.stride
self.idx = bm.Variable(bm.array([0], dtype=bm.uint32))
self.store = None

def init_ff_conn(self):
"""Initialize feedforward connections."""
# input dimension
batch_size, free_size = check_shape_consistency(self.feedforward_shapes, -1, True)
self.input_dim = sum(free_size)
assert batch_size == (None,), f'batch_size must be None, but got {batch_size}'

# linear dimension
linear_dim = self.delay * self.input_dim
# for each monomial created in the non linear part, indices
# for each monomial created in the non-linear part, indices
# of the n components involved, n being the order of the
# monomials. Precompute them to improve efficiency.
idx = np.array(list(combinations_with_replacement(np.arange(linear_dim), self.order)))
self.comb_ids = bm.asarray(idx)
# number of non linear components is (d + n - 1)! / (d - 1)! n!
for order in self.order:
idx = np.array(list(combinations_with_replacement(np.arange(linear_dim), order)))
self.comb_ids.append(bm.asarray(idx))
# number of non-linear components is (d + n - 1)! / (d - 1)! n!
# i.e. number of all unique monomials of order n made from the
# linear components.
nonlinear_dim = len(self.comb_ids)
nonlinear_dim = sum([len(ids) for ids in self.comb_ids])
# output dimension
output_dim = int(linear_dim + nonlinear_dim)
self.output_dim = int(linear_dim + nonlinear_dim)
if self.constant is not None:
output_dim += 1
self.set_output_shape((None, output_dim))

# delay variables
self.num_delay = self.delay * self.stride
self.idx = bm.Variable(bm.array([0], dtype=bm.uint32))
self.store = None
self.output_dim += 1
self.set_output_shape((None, self.output_dim))

def init_state(self, num_batch=1):
# to store the k*s last inputs, k being the delay and s the strides
"""Initialize the node state which depends on batch size."""
# To store the last inputs.
# Note, the batch axis is not in the first dimension, so we
# manually handle the state of NVAR, rather return it.
state = bm.zeros((self.num_delay, num_batch, self.input_dim), dtype=bm.float_)
if self.store is None:
self.store = bm.Variable(state)
else:
self.store.value = state

def forward(self, ff, fb=None, **kwargs):
# 1. store the current input
def forward(self, ff, fb=None, **shared_kwargs):
all_parts = []
# 1. Store the current input
ff = bm.concatenate(ff, axis=-1)
self.store[self.idx[0]] = ff
self.idx.value = (self.idx + 1) % self.num_delay
@@ -124,12 +133,15 @@ class NVAR(RecurrentNode):
select_ids = (self.idx[0] + bm.arange(self.num_delay)[::self.stride]) % self.num_delay
linear_parts = bm.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature)
linear_parts = bm.reshape(linear_parts, (linear_parts.shape[0], -1))
# 3. constant
if self.constant is not None:
constant = bm.broadcast_to(self.constant, linear_parts.shape[:-1] + (1,))
all_parts.append(constant)
all_parts.append(linear_parts)
# 3. Nonlinear part:
# select monomial terms and compute them
nonlinear_parts = bm.prod(linear_parts[:, self.comb_ids], axis=2)
if self.constant is None:
return bm.concatenate([linear_parts, nonlinear_parts], axis=-1)
else:
constant = bm.broadcast_to(self.constant, linear_parts.shape[:-1] + (1,))
return bm.concatenate([constant, linear_parts, nonlinear_parts], axis=-1)
for ids in self.comb_ids:
all_parts.append(bm.prod(linear_parts[:, ids], axis=2))
# 4. Return all parts
return bm.concatenate(all_parts, axis=-1)


+ 5
- 7
brainpy/nn/nodes/RC/reservoir.py View File

@@ -3,9 +3,8 @@
from typing import Optional, Union, Callable

import brainpy.math as bm
from brainpy.initialize import Normal, ZeroInit, Initializer
from brainpy.initialize import Normal, ZeroInit, Initializer, init_param
from brainpy.nn.base import RecurrentNode
from brainpy.nn.utils import init_param
from brainpy.tools.checking import (check_shape_consistency,
check_float,
check_initializer,
@@ -158,7 +157,7 @@ class Reservoir(RecurrentNode):
self.noise_type = noise_type
check_string(noise_type, 'noise_type', ['normal', 'uniform'])

def init_ff(self):
def init_ff_conn(self):
"""Initialize feedforward connections, weights, and variables."""
unique_shape, free_shapes = check_shape_consistency(self.feedforward_shapes, -1, True)
self.set_output_shape(unique_shape + (self.num_unit,))
@@ -197,10 +196,9 @@ class Reservoir(RecurrentNode):

def init_state(self, num_batch=1):
# initialize internal state
state = bm.Variable(bm.zeros((num_batch, self.num_unit), dtype=bm.float_))
self.set_state(state)
return bm.zeros((num_batch, self.num_unit), dtype=bm.float_)

def init_fb(self):
def init_fb_conn(self):
"""Initialize feedback connections, weights, and variables."""
if self.feedback_shapes is not None:
unique_shape, free_shapes = check_shape_consistency(self.feedback_shapes, -1, True)
@@ -215,7 +213,7 @@ class Reservoir(RecurrentNode):
if self.trainable:
self.Wfb = bm.TrainVar(self.Wfb)

def forward(self, ff, fb=None, **kwargs):
def forward(self, ff, fb=None, **shared_kwargs):
"""Feedforward output."""
# inputs
x = bm.concatenate(ff, axis=-1)


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save