#9 Version 1.0.0-alpha

Merged: BrainPy merged 15 commits from develop into master 3 years ago.
Changed files:

  1. .gitignore (+4 -5)
  2. README.md (+26 -16)
  3. brainpy/__init__.py (+14 -26)
  4. brainpy/analysis/__init__.py (+5 -4)
  5. brainpy/analysis/base.py (+84 -81)
  6. brainpy/analysis/bifurcation.py (+121 -164)
  7. brainpy/analysis/phase_plane.py (+121 -119)
  8. brainpy/analysis/solver.py (+18 -9)
  9. brainpy/analysis/stability.py (+161 -0)
  10. brainpy/analysis/trajectory.py (+109 -0)
  11. brainpy/analysis/utils.py (+102 -127)
  12. brainpy/backend/__init__.py (+221 -2)
  13. brainpy/backend/operators/__init__.py (+1 -0)
  14. brainpy/backend/operators/bk_jax.py (+33 -0)
  15. brainpy/backend/operators/bk_numba_cpu.py (+23 -0)
  16. brainpy/backend/operators/bk_numba_cuda.py (+2 -0)
  17. brainpy/backend/operators/bk_numba_overload.py (+18 -6)
  18. brainpy/backend/operators/bk_numpy.py (+33 -0)
  19. brainpy/backend/operators/bk_pytorch.py (+28 -0)
  20. brainpy/backend/operators/bk_tensorflow.py (+31 -0)
  21. brainpy/backend/operators/standard.py (+339 -0)
  22. brainpy/backend/runners/__init__.py (+1 -0)
  23. brainpy/backend/runners/general_runner.py (+258 -0)
  24. brainpy/backend/runners/jax_runner.py (+5 -0)
  25. brainpy/backend/runners/numba_cpu_runner.py (+583 -0)
  26. brainpy/backend/runners/numba_cuda_runner.py (+158 -0)
  27. brainpy/backend/runners/utils.py (+57 -0)
  28. brainpy/backend/utils.py (+0 -44)
  29. brainpy/connectivity/base.py (+98 -85)
  30. brainpy/connectivity/methods.py (+324 -334)
  31. brainpy/core/__init__.py (+0 -8)
  32. brainpy/core/base.py (+0 -609)
  33. brainpy/core/constants.py (+0 -28)
  34. brainpy/core/network.py (+0 -331)
  35. brainpy/core/neurons.py (+0 -181)
  36. brainpy/core/runner.py (+0 -1256)
  37. brainpy/core/synapses.py (+0 -244)
  38. brainpy/core/types.py (+0 -442)
  39. brainpy/core/utils.py (+0 -151)
  40. brainpy/errors.py (+4 -7)
  41. brainpy/inputs.py (+4 -268)
  42. brainpy/integration/__init__.py (+0 -74)
  43. brainpy/integration/constants.py (+0 -24)
  44. brainpy/integration/diff_equation.py (+0 -332)
  45. brainpy/integration/integrator.py (+0 -1128)
  46. brainpy/integrators/__init__.py (+10 -0)
  47. brainpy/integrators/ast_analysis.py (+220 -0)
  48. brainpy/integrators/constants.py (+95 -0)
  49. brainpy/integrators/dde/__init__.py (+1 -0)
  50. brainpy/integrators/delay_vars.py (+73 -0)
  51. brainpy/integrators/fde/__init__.py (+1 -0)
  52. brainpy/integrators/integrate_wrapper.py (+120 -0)
  53. brainpy/integrators/ode/__init__.py (+10 -0)
  54. brainpy/integrators/ode/exp_euler.py (+18 -0)
  55. brainpy/integrators/ode/rk_adaptive_methods.py (+340 -0)
  56. brainpy/integrators/ode/rk_methods.py (+347 -0)
  57. brainpy/integrators/ode/wrapper.py (+322 -0)
  58. brainpy/integrators/sde/__init__.py (+11 -0)
  59. brainpy/integrators/sde/common.py (+53 -0)
  60. brainpy/integrators/sde/euler_and_milstein.py (+276 -0)
  61. brainpy/integrators/sde/exp_euler.py (+218 -0)
  62. brainpy/integrators/sde/srk_scalar.py (+442 -0)
  63. brainpy/integrators/sde/srk_strong.py (+442 -0)
  64. brainpy/integrators/sympy_analysis.py (+259 -203)
  65. brainpy/integrators/utils.py (+111 -0)
  66. brainpy/measure.py (+11 -4)
  67. brainpy/profile.py (+0 -377)
  68. brainpy/simulation/__init__.py (+10 -0)
  69. brainpy/simulation/brain_objects.py (+266 -0)
  70. brainpy/simulation/constants.py (+16 -0)
  71. brainpy/simulation/delay.py (+62 -0)
  72. brainpy/simulation/dynamic_system.py (+181 -0)
  73. brainpy/simulation/monitors.py (+53 -0)
  74. brainpy/simulation/runner.py (+66 -0)
  75. brainpy/simulation/utils.py (+240 -0)
  76. brainpy/tools/__init__.py (+0 -1)
  77. brainpy/tools/ast2code.py (+3 -2)
  78. brainpy/tools/codes.py (+68 -516)
  79. brainpy/tools/dicts.py (+2 -1)
  80. brainpy/tools/functions.py (+0 -125)
  81. brainpy/visualization/figures.py (+7 -7)
  82. brainpy/visualization/plots.py (+4 -4)
  83. develop/benchmark/COBA/COBA.py (+7 -7)
  84. develop/benchmark/COBA/COBA_brainpy.py (+86 -93)
  85. develop/benchmark/COBAHH/COBAHH_brainpy.py (+5 -5)
  86. develop/benchmark/CUBA/CUBA_brainpy.py (+9 -9)
  87. develop/benchmark/scaling_test.py (+23 -20)
  88. develop/conda-recipe/meta.yaml (+1 -2)
  89. docs/Makefile (+1 -1)
  90. docs/advanced/HH_model_in_ANNarchy.ipynb (+0 -291)
  91. docs/advanced/Limitations.rst (+0 -0)
  92. docs/advanced/debugging.ipynb (+0 -393)
  93. docs/advanced/differential_equations.ipynb (+0 -232)
  94. docs/advanced/gapjunction_lif_in_brian2.ipynb (+0 -266)
  95. docs/apis/analysis.rst (+4 -4)
  96. docs/apis/backend.rst (+18 -0)
  97. docs/apis/connectivity.rst (+2 -2)
  98. docs/apis/core.rst (+0 -43)
  99. docs/apis/errors.rst (+5 -5)
  100. docs/apis/inputs.rst (+2 -19)

.gitignore (+4 -5)

@@ -3,7 +3,6 @@ publishment.md
TODO.md
.vscode

examples/
develop/benchmark/COBA/results
develop/outputdir

@@ -17,14 +16,14 @@ docs/images/connection_methods.pptx
docs/tutorials/_autosummary
docs/tutorials/.ipynb_checkpoints

docs/advanced/_autosummary
docs/advanced/.ipynb_checkpoints
docs/tutorials_advanced/_autosummary
docs/tutorials_advanced/.ipynb_checkpoints
develop/fast_synapse_computation.py
develop/fast_synapse_computation2.py

docs/apis/_autosummary
docs/advanced/usage_of_inputfactory.py
docs/advanced/usage_of_utils_connect.py
docs/tutorials_advanced/usage_of_inputfactory.py
docs/tutorials_advanced/usage_of_utils_connect.py

develop/benchmark/COBA/brian2*
develop/benchmark/COBA/annarchy*


README.md (+26 -16)

@@ -1,7 +1,7 @@

![Logo](docs/images/logo.png)

[![LICENSE](https://anaconda.org/brainpy/brainpy/badges/license.svg)](https://github.com/PKU-NIP-Lab/BrainPy) [![Documentation](https://readthedocs.org/projects/brainpy/badge/?version=latest)](https://brainpy.readthedocs.io/en/latest/?badge=latest) [![Conda](https://anaconda.org/brainpy/brainpy-simulator/badges/version.svg)](https://anaconda.org/brainpy/brainpy-simulator) [![PyPI version](https://badge.fury.io/py/brainpy-simulator.svg)](https://badge.fury.io/py/brainpy-simulator) [![travis](https://travis-ci.org/PKU-NIP-Lab/BrainPy.svg?branch=master)](https://travis-ci.org/PKU-NIP-Lab/BrainPy)
[![LICENSE](https://anaconda.org/brainpy/brainpy/badges/license.svg)](https://github.com/PKU-NIP-Lab/BrainPy) [![Documentation](https://readthedocs.org/projects/brainpy/badge/?version=latest)](https://brainpy.readthedocs.io/en/latest/?badge=latest) [![Conda](https://anaconda.org/brainpy/brainpy-simulator/badges/version.svg)](https://anaconda.org/brainpy/brainpy-simulator) [![PyPI version](https://badge.fury.io/py/brainpy-simulator.svg)](https://badge.fury.io/py/brainpy-simulator)



@@ -11,12 +11,18 @@

## Why use BrainPy

``BrainPy`` is a lightweight framework based on the latest Just-In-Time (JIT) compilers (especially [Numba](https://numba.pydata.org/)). The goal of ``BrainPy`` is to provide a unified simulation and analysis framework for neuronal dynamics with the feature of high flexibility and efficiency. BrainPy is flexible because it endows the users with the fully data/logic flow control. BrainPy is efficient because it supports JIT acceleration on CPUs and GPUs.
``BrainPy`` is an integrative framework for computational neuroscience and brain-inspired computation. Three core functions are provided in `BrainPy`:

- *General numerical solvers* for ODEs and SDEs (support for DDEs and FDEs is planned).
- *Neurodynamics simulation tools* for brain objects such as neurons, synapses and networks (support for soma and dendrites is planned).
- *Neurodynamics analysis tools* for differential equations, including phase plane analysis and bifurcation analysis (continuation analysis and sensitivity analysis are planned).

![Speed Comparison with Brian2](docs/images/speed.png)
Moreover, `BrainPy` can effectively satisfy your basic requirements: 1. *Easy to learn and use*, because it is based only on the Python language and has few dependencies; 2. *Highly flexible and transparent*, because it gives users full data/logic flow control; 3. *Simulation can be guided by analysis*, because the same code in BrainPy can be used not only for simulation but also for dynamics analysis; 4. *Efficient running speed*, because BrainPy is compatible with the latest JIT compilers (or any other computing backend you prefer).

![Scaling of BrainPy](docs/images/speed_scaling.png)



![Speed Comparison](docs/images/speed.png)



@@ -24,30 +30,34 @@

Install ``BrainPy`` using ``pip``:

> pip install brainpy-simulator
```bash
> pip install brainpy-simulator
```

Install ``BrainPy`` using ``conda``:

> conda install brainpy-simulator -c brainpy
```bash
> conda install brainpy-simulator -c brainpy
```

Install ``BrainPy`` from source:

> pip install git+https://github.com/PKU-NIP-Lab/BrainPy
> # or
> pip install git+https://git.openi.org.cn/OpenI/BrainPy
> # or
> pip install -e git://github.com/PKU-NIP-Lab/BrainPy.git@V0.2.5
```bash
> pip install git+https://github.com/PKU-NIP-Lab/BrainPy
> # or
> pip install git+https://git.openi.org.cn/OpenI/BrainPy
> # or
> pip install -e git://github.com/PKU-NIP-Lab/BrainPy.git@V0.2.5
```

``BrainPy`` is based on Python (>=3.7), and the following packages are
required to be installed to use ``BrainPy``:
``BrainPy`` is based on Python (>=3.7), and the following packages are required to be installed to use ``BrainPy``:

- NumPy >= 1.13
- SymPy >= 1.2
- SciPy >= 1.2
- Numba >= 0.50.0
- Matplotlib >= 3.0




## Neurodynamics simulation

<table border="0">

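The new README's point that "simulation can be guided by analysis" is easiest to see in code: the same model definition drives both a run and a dynamics analysis. A minimal sketch against the 1.0.0-alpha API in this PR; `PhasePlane`, its constructor arguments, and `plot_vector_field` appear in the diffs below, while the decorator form of `odeint` is an assumption, not documented behavior.

```python
import brainpy as bp

# A scalar ODE, dV/dt = V * (1 - V^2) + I.  `bp.odeint` is the new
# general ODE-solver entry point (brainpy/integrators/ in this PR);
# the decorator usage shown here is an assumption.
@bp.odeint
def int_V(V, t, I):
    return V * (1. - V * V) + I

# The same integral function feeds the analysis tools directly.
pp = bp.analysis.PhasePlane(int_V,
                            target_vars={'V': [-3., 3.]},  # range to scan
                            pars_update={'I': 0.5})
pp.plot_vector_field(show=True)
```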

brainpy/__init__.py (+14 -26)

@@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-

__version__ = "0.3.6"
__version__ = "1.0.0-alpha"

# "profile" module
from . import profile
# "analysis" module
from . import analysis

# "backend" module
from . import backend
@@ -12,35 +12,23 @@ from . import backend
from . import connectivity
from . import connectivity as connect

# "core" module
from . import core as core
from .core.base import ObjType
from .core.base import Ensemble
from .core.neurons import NeuType
from .core.neurons import NeuGroup
from .core.synapses import SynType
from .core.synapses import SynConn
from .core.synapses import delayed
from .core.network import Network
from .core import types
from .core.types import ObjState
from .core.types import NeuState
from .core.types import SynState

# "integration" module
from . import integration
from .integration import integrate
# "simulation" module
from . import simulation
from .simulation.dynamic_system import *
from .simulation.brain_objects import *

# "analysis" module
from . import analysis

# "tools" module
from . import tools
# "integrators" module
from . import integrators
from .integrators import ode
from .integrators import sde
from .integrators.integrate_wrapper import *
from .integrators.constants import *

# "visualization" module
from . import visualization as visualize

# other modules
from . import tools
from . import inputs
from . import measure
from . import running

brainpy/analysis/__init__.py (+5 -4)

@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-

from .solver import *
from .utils import *
from .base import *
from .phase_plane import *
from .bifurcation import *

from .phase_plane import *
from .solver import *
from .stability import *
from .trajectory import *
from .utils import *

brainpy/analysis/base.py (+84 -81)

@@ -6,12 +6,11 @@ from copy import deepcopy
import numpy as np
import sympy

from . import solver
from . import utils
from .. import core
from .. import errors
from .. import integration
from .. import tools
from brainpy import errors
from brainpy import tools
from brainpy.analysis import solver
from brainpy.analysis import utils
from brainpy.integrators import sympy_analysis

__all__ = [
'BaseNeuronAnalyzer',
@@ -35,8 +34,9 @@ class BaseNeuronAnalyzer(object):

Parameters
----------
model : core.NeuType
The neuronal type model.
model_or_integrals : simulation.Population, callable, or list/tuple of callables
A model of the population, the integrator function,
or a list/tuple of integrator functions.
target_vars : dict
The target/dynamical variables.
fixed_vars : dict
@@ -71,7 +71,7 @@ class BaseNeuronAnalyzer(object):
"""

def __init__(self,
model,
model_or_integrals,
target_vars,
fixed_vars=None,
target_pars=None,
@@ -80,11 +80,14 @@ class BaseNeuronAnalyzer(object):
options=None):

# model
# ------
if not isinstance(model, core.NeuType):
raise errors.ModelUseError(f'Neuron Dynamics Analyzer now only support NeuType, '
f'but get {type(model)}.')
self.model = model
# -----
if isinstance(model_or_integrals, utils.DynamicModel):
self.model = model_or_integrals
elif (isinstance(model_or_integrals, (tuple, list)) and callable(model_or_integrals[0])) or \
callable(model_or_integrals):
self.model = utils.transform_integrals_to_model(model_or_integrals)
else:
raise ValueError

# target variables
# ----------------
@@ -96,43 +99,40 @@ class BaseNeuronAnalyzer(object):
self.dvar_names = list(self.target_vars.keys())
else:
self.dvar_names = list(sorted(self.target_vars.keys()))
for key in self.target_vars.keys():
if key not in self.model.variables:
raise ValueError(f'{key} is not a dynamical variable in {self.model}.')

# fixed variables
# ----------------
if fixed_vars is None:
fixed_vars = dict()
if not isinstance(fixed_vars, dict):
raise errors.ModelUseError('"fixed_vars" must be a dict with the format '
'of {"var1": val1, "var2": val2}.')
self.fixed_vars = dict()
for integrator in model.integrators:
var_name = integrator.diff_eq.var_name
if var_name not in target_vars:
if var_name in fixed_vars:
self.fixed_vars[var_name] = fixed_vars.get(var_name)
else:
self.fixed_vars[var_name] = model.variables.get(var_name)
for key in fixed_vars.keys():
if key not in self.fixed_vars:
self.fixed_vars[key] = fixed_vars.get(key)
if key not in self.model.variables:
raise ValueError(f'{key} is not a dynamical variable in {self.model}.')
self.fixed_vars = fixed_vars

# check duplicate
for key in self.fixed_vars.keys():
if key in self.target_vars:
raise errors.ModelUseError(f'"{key}" is defined as a target variable in "target_vars", '
f'but also defined as a fixed variable in "fixed_vars".')

# equations of dynamical variables
# --------------------------------
var2eq = {integrator.diff_eq.var_name: integrator for integrator in model.integrators}
target_func_args = set()
var2eq = {ana.var_name: ana for ana in self.model.analyzers}
self.target_eqs = tools.DictPlus()
for key in self.target_vars.keys():
if key not in var2eq:
raise errors.ModelUseError(f'target "{key}" is not a dynamical variable.')
integrator = var2eq[key]
diff_eq = integrator.diff_eq
diff_eq = var2eq[key]
sub_exprs = diff_eq.get_f_expressions(substitute_vars=list(self.target_vars.keys()))
old_exprs = diff_eq.get_f_expressions(substitute_vars=None)
self.target_eqs[key] = tools.DictPlus(sub_exprs=sub_exprs,
old_exprs=old_exprs,
diff_eq=diff_eq,
func_name=diff_eq.func_name)
target_func_args.update(diff_eq.func_args)

# parameters to update
# ---------------------
@@ -142,9 +142,8 @@ class BaseNeuronAnalyzer(object):
raise errors.ModelUseError('"pars_update" must be a dict with the format '
'of {"par1": val1, "par2": val2}.')
for key in pars_update.keys():
if key not in model.step_scopes:
if key not in target_func_args:
raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{model.name}" model.')
if (key not in self.model.scopes) and (key not in self.model.parameters):
raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{self.model}" model.')
self.pars_update = pars_update

# dynamical parameters
@@ -152,18 +151,22 @@ class BaseNeuronAnalyzer(object):
if target_pars is None:
target_pars = dict()
if not isinstance(target_pars, dict):
raise errors.ModelUseError('"pars_dynamical" must be a dict with the format '
'of {"par1": (val1, val2)}.')
raise errors.ModelUseError('"target_pars" must be a dict with the format of {"par1": (val1, val2)}.')
for key in target_pars.keys():
if key not in model.step_scopes:
if key not in target_func_args:
raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{model.name}" model.')
if (key not in self.model.scopes) and (key not in self.model.parameters):
raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{self.model}" model.')
self.target_pars = target_pars
if isinstance(self.target_vars, OrderedDict):
self.dpar_names = list(self.target_pars.keys())
else:
self.dpar_names = list(sorted(self.target_pars.keys()))

# check duplicate
for key in self.pars_update.keys():
if key in self.target_pars:
raise errors.ModelUseError(f'"{key}" is defined as a target parameter in "target_pars", '
f'but also defined as a fixed parameter in "pars_update".')

# resolutions for numerical methods
# ---------------------------------
self.resolutions = dict()
@@ -242,7 +245,7 @@ class Base1DNeuronAnalyzer(BaseNeuronAnalyzer):
if 'dxdt' not in self.analyzed_results:
scope = deepcopy(self.pars_update)
scope.update(self.fixed_vars)
scope.update(integration.get_mapping_scope())
scope.update(sympy_analysis.get_mapping_scope())
scope.update(self.x_eq_group.diff_eq.func_scope)
argument = ', '.join(self.dvar_names + self.dpar_names)
func_code = f'def func({argument}):\n'
@@ -260,11 +263,11 @@ class Base1DNeuronAnalyzer(BaseNeuronAnalyzer):
x_var = self.dvar_names[0]
x_symbol = sympy.Symbol(x_var, real=True)
x_eq = self.x_eq_group.sub_exprs[-1].code
x_eq = integration.str2sympy(x_eq)
x_eq = sympy_analysis.str2sympy(x_eq)

eq_x_scope = deepcopy(self.pars_update)
eq_x_scope.update(self.fixed_vars)
eq_x_scope.update(integration.get_mapping_scope())
eq_x_scope.update(sympy_analysis.get_mapping_scope())
eq_x_scope.update(self.x_eq_group['diff_eq'].func_scope)

argument = ', '.join(self.dvar_names + self.dpar_names)
@@ -282,7 +285,7 @@ class Base1DNeuronAnalyzer(BaseNeuronAnalyzer):
# check
all_vars = set(eq_x_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(integration.sympy2str(dfxdx_expr), all_vars):
if utils.contain_unknown_symbol(sympy_analysis.sympy2str(dfxdx_expr), all_vars):
print('failed because it contains unknown symbols.')
sympy_failed = True
else:
@@ -290,7 +293,7 @@ class Base1DNeuronAnalyzer(BaseNeuronAnalyzer):
func_codes = [f'def dfdx({argument}):']
for expr in self.x_eq_group.sub_exprs[:-1]:
func_codes.append(f'{expr.var_name} = {expr.code}')
func_codes.append(f'return {integration.sympy2str(dfxdx_expr)}')
func_codes.append(f'return {sympy_analysis.sympy2str(dfxdx_expr)}')
exec(compile('\n '.join(func_codes), '', 'exec'), eq_x_scope)
dfdx = eq_x_scope['dfdx']
sympy_failed = False
@@ -320,13 +323,13 @@ class Base1DNeuronAnalyzer(BaseNeuronAnalyzer):
"""

if 'fixed_point' not in self.analyzed_results:
x_eq = integration.str2sympy(self.x_eq_group.sub_exprs[-1].code)
x_eq = sympy_analysis.str2sympy(self.x_eq_group.sub_exprs[-1].code)

scope = deepcopy(self.pars_update)
scope.update(self.fixed_vars)
scope.update(integration.get_mapping_scope())
scope.update(sympy_analysis.get_mapping_scope())
scope.update(self.x_eq_group.diff_eq.func_scope)
scope['numpy'] = np
scope['np'] = np

timeout_len = self.options.sympy_solver_timeout
argument1 = ', '.join(self.dvar_names + self.dpar_names)
@@ -345,7 +348,7 @@ class Base1DNeuronAnalyzer(BaseNeuronAnalyzer):
for res in results:
all_vars = set(scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(integration.sympy2str(res), all_vars):
if utils.contain_unknown_symbol(sympy_analysis.sympy2str(res), all_vars):
print('failed because it contains unknown symbols.')
sympy_failed = True
break
@@ -355,7 +358,8 @@ class Base1DNeuronAnalyzer(BaseNeuronAnalyzer):
func_codes = [f'def solve_x({argument2}):']
for expr in self.x_eq_group.sub_exprs[:-1]:
func_codes.append(f'{expr.var_name} = {expr.code}')
result_expr = ', '.join([integration.sympy2str(expr) for expr in results])
result_expr = ', '.join([sympy_analysis.sympy2str(expr)
for expr in results])
func_codes.append(f'_res_ = {result_expr}')
func_codes.append(f'return np.array(_res_)')

@@ -454,7 +458,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
# check "f"
scope = deepcopy(self.pars_update)
scope.update(self.fixed_vars)
scope.update(integration.get_mapping_scope())
scope.update(sympy_analysis.get_mapping_scope())
if a.endswith('y_eq'):
scope.update(self.y_eq_group['diff_eq'].func_scope)
else:
@@ -485,7 +489,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
y_var = self.dvar_names[1]
scope = deepcopy(self.pars_update)
scope.update(self.fixed_vars)
scope.update(integration.get_mapping_scope())
scope.update(sympy_analysis.get_mapping_scope())
scope.update(self.y_eq_group.diff_eq.func_scope)
argument = ', '.join(self.dvar_names + self.dpar_names)
func_code = f'def func({argument}):\n'
@@ -503,11 +507,11 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
y_var = self.dvar_names[1]
y_symbol = sympy.Symbol(y_var, real=True)
x_eq = self.target_eqs[x_var].sub_exprs[-1].code
x_eq = integration.str2sympy(x_eq)
x_eq = sympy_analysis.str2sympy(x_eq)

eq_x_scope = deepcopy(self.pars_update)
eq_x_scope.update(self.fixed_vars)
eq_x_scope.update(integration.get_mapping_scope())
eq_x_scope.update(sympy_analysis.get_mapping_scope())
eq_x_scope.update(self.x_eq_group['diff_eq'].func_scope)

argument = ', '.join(self.dvar_names + self.dpar_names)
@@ -525,7 +529,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
# check
all_vars = set(eq_x_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(integration.sympy2str(dfxdy_expr), all_vars):
if utils.contain_unknown_symbol(sympy_analysis.sympy2str(dfxdy_expr), all_vars):
print('failed because it contains unknown symbols.')
sympy_failed = True
else:
@@ -533,7 +537,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
func_codes = [f'def dfdy({argument}):']
for expr in self.x_eq_group.sub_exprs[:-1]:
func_codes.append(f'{expr.var_name} = {expr.code}')
func_codes.append(f'return {integration.sympy2str(dfxdy_expr)}')
func_codes.append(f'return {sympy_analysis.sympy2str(dfxdy_expr)}')
exec(compile('\n '.join(func_codes), '', 'exec'), eq_x_scope)
dfdy = eq_x_scope['dfdy']
sympy_failed = False
@@ -566,11 +570,11 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
x_symbol = sympy.Symbol(x_var, real=True)
y_var = self.dvar_names[1]
y_eq = self.target_eqs[y_var].sub_exprs[-1].code
y_eq = integration.str2sympy(y_eq)
y_eq = sympy_analysis.str2sympy(y_eq)

eq_y_scope = deepcopy(self.pars_update)
eq_y_scope.update(self.fixed_vars)
eq_y_scope.update(integration.get_mapping_scope())
eq_y_scope.update(sympy_analysis.get_mapping_scope())
eq_y_scope.update(self.y_eq_group['diff_eq'].func_scope)

argument = ', '.join(self.dvar_names + self.dpar_names)
@@ -588,7 +592,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
# check
all_vars = set(eq_y_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(integration.sympy2str(dfydx_expr), all_vars):
if utils.contain_unknown_symbol(sympy_analysis.sympy2str(dfydx_expr), all_vars):
print('failed because it contains unknown symbols.')
sympy_failed = True
else:
@@ -596,7 +600,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
func_codes = [f'def dgdx({argument}):']
for expr in self.y_eq_group.sub_exprs[:-1]:
func_codes.append(f'{expr.var_name} = {expr.code}')
func_codes.append(f'return {integration.sympy2str(dfydx_expr)}')
func_codes.append(f'return {sympy_analysis.sympy2str(dfydx_expr)}')
exec(compile('\n '.join(func_codes), '', 'exec'), eq_y_scope)
dgdx = eq_y_scope['dgdx']
sympy_failed = False
@@ -629,11 +633,11 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
y_var = self.dvar_names[1]
y_symbol = sympy.Symbol(y_var, real=True)
y_eq = self.target_eqs[y_var].sub_exprs[-1].code
y_eq = integration.str2sympy(y_eq)
y_eq = sympy_analysis.str2sympy(y_eq)

eq_y_scope = deepcopy(self.pars_update)
eq_y_scope.update(self.fixed_vars)
eq_y_scope.update(integration.get_mapping_scope())
eq_y_scope.update(sympy_analysis.get_mapping_scope())
eq_y_scope.update(self.y_eq_group['diff_eq'].func_scope)

argument = ', '.join(self.dvar_names + self.dpar_names)
@@ -652,7 +656,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
# check
all_vars = set(eq_y_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(integration.sympy2str(dfydx_expr), all_vars):
if utils.contain_unknown_symbol(sympy_analysis.sympy2str(dfydx_expr), all_vars):
print('failed because it contains unknown symbols.')
sympy_failed = True
else:
@@ -660,7 +664,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
func_codes = [f'def dgdy({argument}):']
for expr in self.y_eq_group.sub_exprs[:-1]:
func_codes.append(f'{expr.var_name} = {expr.code}')
func_codes.append(f'return {integration.sympy2str(dfydx_expr)}')
func_codes.append(f'return {sympy_analysis.sympy2str(dfydx_expr)}')
exec(compile('\n '.join(func_codes), '', 'exec'), eq_y_scope)
dgdy = eq_y_scope['dgdy']
sympy_failed = False
@@ -712,12 +716,11 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
"""

if 'fixed_point' not in self.analyzed_results:

vars_and_pars = ','.join(self.dvar_names[2:] + self.dpar_names)

eq_xy_scope = deepcopy(self.pars_update)
eq_xy_scope.update(self.fixed_vars)
eq_xy_scope.update(integration.get_mapping_scope())
eq_xy_scope.update(sympy_analysis.get_mapping_scope())
eq_xy_scope.update(self.x_eq_group['diff_eq'].func_scope)
eq_xy_scope.update(self.y_eq_group['diff_eq'].func_scope)

@@ -826,7 +829,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
# f
eq_x_scope = deepcopy(self.pars_update)
eq_x_scope.update(self.fixed_vars)
eq_x_scope.update(integration.get_mapping_scope())
eq_x_scope.update(sympy_analysis.get_mapping_scope())
eq_x_scope.update(self.x_eq_group['diff_eq'].func_scope)
func_codes = [f'def f_x({",".join(self.dvar_names + self.dpar_names)}):']
func_codes.extend([f'{expr.var_name} = {expr.code}'
@@ -838,7 +841,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
# g
eq_y_scope = deepcopy(self.pars_update)
eq_y_scope.update(self.fixed_vars)
eq_y_scope.update(integration.get_mapping_scope())
eq_y_scope.update(sympy_analysis.get_mapping_scope())
eq_y_scope.update(self.y_eq_group['diff_eq'].func_scope)
func_codes = [f'def g_y({",".join(self.dvar_names + self.dpar_names)}):']
func_codes.extend([f'{expr.var_name} = {expr.code}'
@@ -893,7 +896,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
# x equation scope
eq_x_scope = deepcopy(self.pars_update)
eq_x_scope.update(self.fixed_vars)
eq_x_scope.update(integration.get_mapping_scope())
eq_x_scope.update(sympy_analysis.get_mapping_scope())
eq_x_scope.update(self.x_eq_group.diff_eq.func_scope)

argument = ','.join(self.dvar_names[2:] + self.dpar_names)
@@ -967,7 +970,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
# y equation scope
eq_y_scope = deepcopy(self.pars_update)
eq_y_scope.update(self.fixed_vars)
eq_y_scope.update(integration.get_mapping_scope())
eq_y_scope.update(sympy_analysis.get_mapping_scope())
eq_y_scope.update(self.y_eq_group.diff_eq.func_scope)

argument = ','.join(self.dvar_names[2:] + self.dpar_names)
@@ -1031,11 +1034,11 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
if not self.options.escape_sympy_solver:
y_symbol = sympy.Symbol(self.y_var, real=True)
code = self.target_eqs[self.y_var].sub_exprs[-1].code
y_eq = integration.str2sympy(code).expr
y_eq = sympy_analysis.str2sympy(code).expr

eq_y_scope = deepcopy(self.pars_update)
eq_y_scope.update(self.fixed_vars)
eq_y_scope.update(integration.get_mapping_scope())
eq_y_scope.update(sympy_analysis.get_mapping_scope())
eq_y_scope.update(self.y_eq_group['diff_eq'].func_scope)

argument = ', '.join(self.dvar_names + self.dpar_names)
@@ -1051,7 +1054,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
y_by_x_in_y_eq = f()
if len(y_by_x_in_y_eq) > 1:
raise NotImplementedError('Do not support multiple values.')
y_by_x_in_y_eq = integration.sympy2str(y_by_x_in_y_eq[0])
y_by_x_in_y_eq = sympy_analysis.sympy2str(y_by_x_in_y_eq[0])

# check
all_vars = set(eq_y_scope.keys())
@@ -1105,11 +1108,11 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
if not self.options.escape_sympy_solver:
y_symbol = sympy.Symbol(self.y_var, real=True)
code = self.x_eq_group.sub_exprs[-1].code
x_eq = integration.str2sympy(code).expr
x_eq = sympy_analysis.str2sympy(code).expr

eq_x_scope = deepcopy(self.pars_update)
eq_x_scope.update(self.fixed_vars)
eq_x_scope.update(integration.get_mapping_scope())
eq_x_scope.update(sympy_analysis.get_mapping_scope())
eq_x_scope.update(self.x_eq_group['diff_eq'].func_scope)

argument = ', '.join(self.dvar_names + self.dpar_names)
@@ -1126,7 +1129,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
y_by_x_in_x_eq = f()
if len(y_by_x_in_x_eq) > 1:
raise NotImplementedError('Do not support multiple values.')
y_by_x_in_x_eq = integration.sympy2str(y_by_x_in_x_eq[0])
y_by_x_in_x_eq = sympy_analysis.sympy2str(y_by_x_in_x_eq[0])

all_vars = set(eq_x_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
@@ -1179,11 +1182,11 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
if not self.options.escape_sympy_solver:
x_symbol = sympy.Symbol(self.x_var, real=True)
code = self.target_eqs[self.y_var].sub_exprs[-1].code
y_eq = integration.str2sympy(code).expr
y_eq = sympy_analysis.str2sympy(code).expr

eq_y_scope = deepcopy(self.pars_update)
eq_y_scope.update(self.fixed_vars)
eq_y_scope.update(integration.get_mapping_scope())
eq_y_scope.update(sympy_analysis.get_mapping_scope())
eq_y_scope.update(self.y_eq_group['diff_eq'].func_scope)

argument = ', '.join(self.dvar_names + self.dpar_names)
@@ -1198,7 +1201,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
x_by_y_in_y_eq = f()
if len(x_by_y_in_y_eq) > 1:
raise NotImplementedError('Do not support multiple values.')
x_by_y_in_y_eq = integration.sympy2str(x_by_y_in_y_eq[0])
x_by_y_in_y_eq = sympy_analysis.sympy2str(x_by_y_in_y_eq[0])

# check
all_vars = set(eq_y_scope.keys())
@@ -1252,11 +1255,11 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
if not self.options.escape_sympy_solver:
x_symbol = sympy.Symbol(self.x_var, real=True)
code = self.x_eq_group.sub_exprs[-1].code
x_eq = integration.str2sympy(code).expr
x_eq = sympy_analysis.str2sympy(code).expr

eq_x_scope = deepcopy(self.pars_update)
eq_x_scope.update(self.fixed_vars)
eq_x_scope.update(integration.get_mapping_scope())
eq_x_scope.update(sympy_analysis.get_mapping_scope())
eq_x_scope.update(self.x_eq_group['diff_eq'].func_scope)

argument = ', '.join(self.dvar_names + self.dpar_names)
@@ -1271,7 +1274,7 @@ class Base2DNeuronAnalyzer(Base1DNeuronAnalyzer):
x_by_y_in_x_eq = f()
if len(x_by_y_in_x_eq) > 1:
raise NotImplementedError('Do not support multiple values.')
x_by_y_in_x_eq = integration.sympy2str(x_by_y_in_x_eq[0])
x_by_y_in_x_eq = sympy_analysis.sympy2str(x_by_y_in_x_eq[0])

# check
all_vars = set(eq_x_scope.keys())

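One recurring pattern in this file: the fixed-point routines above return `dfdx` (1D) or a Jacobian (2D), and the analyzers in bifurcation.py and phase_plane.py below hand these to `stability.stability_analysis` from the new brainpy/analysis/stability.py, whose body is not shown in this PR. The sketch below is only the textbook trace/determinant rule such a routine implements; the label strings mirror constants used below (e.g. `stability.UNSTABLE_NODE_2D`) but the implementation is an assumption, not the PR's code.

```python
import numpy as np

def classify_fixed_point_2d(jacobian):
    # Saddle: eigenvalues of opposite sign.
    tr = np.trace(jacobian)
    det = np.linalg.det(jacobian)
    if det < 0:
        return 'saddle node'
    # Real eigenvalues (discriminant >= 0): a node.
    if tr ** 2 - 4. * det >= 0:
        return 'stable node' if tr < 0 else 'unstable node'
    # Complex eigenvalues: a focus (or a center when tr == 0).
    if tr < 0:
        return 'stable focus'
    return 'unstable focus' if tr > 0 else 'center'

# 1D rule used by the 1D analyzers: x* is stable iff dfdx(x*) < 0.
print(classify_fixed_point_2d(np.array([[0., 1.],
                                        [-1., 0.1]])))  # unstable focus
```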

brainpy/analysis/bifurcation.py (+121 -164)

@@ -1,16 +1,18 @@
# -*- coding: utf-8 -*-

import gc
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

from . import base
from . import utils
from .. import core
from .. import errors
from .. import profile
from brainpy import backend
from brainpy import errors
from brainpy.analysis import base
from brainpy.analysis import stability
from brainpy.analysis import utils
from brainpy.analysis.trajectory import Trajectory

__all__ = [
'Bifurcation',
@@ -36,18 +38,16 @@ class Bifurcation(object):
Parameters
----------

model : NeuType
An abstract neuronal type defined in BrainPy.
integrals : callable
The integral functions defined with `brainpy.odeint`,
`brainpy.sdeint`, `brainpy.ddeint`, or `brainpy.fdeint`.

"""

def __init__(self, model, target_pars, target_vars, fixed_vars=None, pars_update=None,
def __init__(self, integrals, target_pars, target_vars, fixed_vars=None, pars_update=None,
numerical_resolution=0.1, options=None):

# check "model"
if not isinstance(model, core.NeuType):
raise errors.ModelUseError('Bifurcation analysis only support neuron type model.')
self.model = model
self.model = utils.transform_integrals_to_model(integrals)

# check "target_pars"
if not isinstance(target_pars, dict):
@@ -80,13 +80,13 @@ class Bifurcation(object):
raise errors.ModelUseError('"pars_update" must be a dict the format of: '
'{"Par A": A_value, "Par B": B_value}')
for key in pars_update.keys():
if key not in model.step_scopes:
raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{model.name}" model. ')
if (key not in self.model.scopes) and (key not in self.model.parameters):
raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{integrals}". ')
self.pars_update = pars_update

# bifurcation analysis
if len(self.target_vars) == 1:
self.analyzer = _Bifurcation1D(model=model,
self.analyzer = _Bifurcation1D(model_or_integrals=self.model,
target_pars=target_pars,
target_vars=target_vars,
fixed_vars=fixed_vars,
@@ -95,7 +95,7 @@ class Bifurcation(object):
options=options)

elif len(self.target_vars) == 2:
self.analyzer = _Bifurcation2D(model=model,
self.analyzer = _Bifurcation2D(model_or_integrals=self.model,
target_pars=target_pars,
target_vars=target_vars,
fixed_vars=fixed_vars,
@@ -116,9 +116,9 @@ class _Bifurcation1D(base.Base1DNeuronAnalyzer):
Using this class, we can make co-dimension1 or co-dimension2 bifurcation analysis.
"""

def __init__(self, model, target_pars, target_vars, fixed_vars=None,
def __init__(self, model_or_integrals, target_pars, target_vars, fixed_vars=None,
pars_update=None, numerical_resolution=0.1, options=None):
super(_Bifurcation1D, self).__init__(model=model,
super(_Bifurcation1D, self).__init__(model_or_integrals=model_or_integrals,
target_pars=target_pars,
target_vars=target_vars,
fixed_vars=fixed_vars,
@@ -133,7 +133,7 @@ class _Bifurcation1D(base.Base1DNeuronAnalyzer):
f_dfdx = self.get_f_dfdx()

if len(self.target_pars) == 1:
container = {c: {'p': [], 'x': []} for c in utils.get_1d_classification()}
container = {c: {'p': [], 'x': []} for c in stability.get_1d_stability_types()}

# fixed point
par_a = self.dpar_names[0]
@@ -141,7 +141,7 @@ class _Bifurcation1D(base.Base1DNeuronAnalyzer):
xs = f_fixed_point(p)
for x in xs:
dfdx = f_dfdx(x, p)
fp_type = utils.stability_analysis(dfdx)
fp_type = stability.stability_analysis(dfdx)
container[fp_type]['p'].append(p)
container[fp_type]['x'].append(x)

@@ -149,7 +149,7 @@ class _Bifurcation1D(base.Base1DNeuronAnalyzer):
plt.figure(self.x_var)
for fp_type, points in container.items():
if len(points['x']):
plot_style = utils.plot_scheme[fp_type]
plot_style = stability.plot_scheme[fp_type]
plt.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
plt.xlabel(par_a)
plt.ylabel(self.x_var)
@@ -163,7 +163,7 @@ class _Bifurcation1D(base.Base1DNeuronAnalyzer):
plt.show()

elif len(self.target_pars) == 2:
container = {c: {'p0': [], 'p1': [], 'x': []} for c in utils.get_1d_classification()}
container = {c: {'p0': [], 'p1': [], 'x': []} for c in stability.get_1d_stability_types()}

# fixed point
for p0 in self.resolutions[self.dpar_names[0]]:
@@ -171,7 +171,7 @@ class _Bifurcation1D(base.Base1DNeuronAnalyzer):
xs = f_fixed_point(p0, p1)
for x in xs:
dfdx = f_dfdx(x, p0, p1)
fp_type = utils.stability_analysis(dfdx)
fp_type = stability.stability_analysis(dfdx)
container[fp_type]['p0'].append(p0)
container[fp_type]['p1'].append(p1)
container[fp_type]['x'].append(x)
@@ -181,7 +181,7 @@ class _Bifurcation1D(base.Base1DNeuronAnalyzer):
ax = fig.gca(projection='3d')
for fp_type, points in container.items():
if len(points['x']):
plot_style = utils.plot_scheme[fp_type]
plot_style = stability.plot_scheme[fp_type]
xs = points['p0']
ys = points['p1']
zs = points['x']
@@ -210,16 +210,15 @@ class _Bifurcation1D(base.Base1DNeuronAnalyzer):
raise NotImplementedError('1D phase plane does not support plot_limit_cycle_by_sim.')



class _Bifurcation2D(base.Base2DNeuronAnalyzer):
"""Bifurcation analysis of 2D system.

Using this class, we can make co-dimension1 or co-dimension2 bifurcation analysis.
"""

def __init__(self, model, target_pars, target_vars, fixed_vars=None,
def __init__(self, model_or_integrals, target_pars, target_vars, fixed_vars=None,
pars_update=None, numerical_resolution=0.1, options=None):
super(_Bifurcation2D, self).__init__(model=model,
super(_Bifurcation2D, self).__init__(model_or_integrals=model_or_integrals,
target_pars=target_pars,
target_vars=target_vars,
fixed_vars=fixed_vars,
@@ -228,9 +227,6 @@ class _Bifurcation2D(base.Base2DNeuronAnalyzer):
options=options)

self.fixed_points = None
self.limit_cycle_mon = None
self.limit_cycle_p0 = None
self.limit_cycle_p1 = None

def plot_bifurcation(self, show=False):
print('plot bifurcation ...')
@@ -242,14 +238,14 @@ class _Bifurcation2D(base.Base2DNeuronAnalyzer):
# bifurcation analysis of co-dimension 1
if len(self.target_pars) == 1:
container = {c: {'p': [], self.x_var: [], self.y_var: []}
for c in utils.get_2d_classification()}
for c in stability.get_2d_stability_types()}

# fixed point
for p in self.resolutions[self.dpar_names[0]]:
xs, ys = f_fixed_point(p)
for x, y in zip(xs, ys):
dfdx = f_jacobian(x, y, p)
fp_type = utils.stability_analysis(dfdx)
fp_type = stability.stability_analysis(dfdx)
container[fp_type]['p'].append(p)
container[fp_type][self.x_var].append(x)
container[fp_type][self.y_var].append(y)
@@ -259,7 +255,7 @@ class _Bifurcation2D(base.Base2DNeuronAnalyzer):
plt.figure(var)
for fp_type, points in container.items():
if len(points['p']):
plot_style = utils.plot_scheme[fp_type]
plot_style = stability.plot_scheme[fp_type]
plt.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
plt.xlabel(self.dpar_names[0])
plt.ylabel(var)
@@ -275,7 +271,7 @@ class _Bifurcation2D(base.Base2DNeuronAnalyzer):
# bifurcation analysis of co-dimension 2
elif len(self.target_pars) == 2:
container = {c: {'p0': [], 'p1': [], self.x_var: [], self.y_var: []}
for c in utils.get_2d_classification()}
for c in stability.get_2d_stability_types()}

# fixed point
for p0 in self.resolutions[self.dpar_names[0]]:
@@ -283,7 +279,7 @@ class _Bifurcation2D(base.Base2DNeuronAnalyzer):
xs, ys = f_fixed_point(p0, p1)
for x, y in zip(xs, ys):
dfdx = f_jacobian(x, y, p0, p1)
fp_type = utils.stability_analysis(dfdx)
fp_type = stability.stability_analysis(dfdx)
container[fp_type]['p0'].append(p0)
container[fp_type]['p1'].append(p1)
container[fp_type][self.x_var].append(x)
@@ -295,7 +291,7 @@ class _Bifurcation2D(base.Base2DNeuronAnalyzer):
ax = fig.gca(projection='3d')
for fp_type, points in container.items():
if len(points['p0']):
plot_style = utils.plot_scheme[fp_type]
plot_style = stability.plot_scheme[fp_type]
xs = points['p0']
ys = points['p1']
zs = points[var]
@@ -320,7 +316,7 @@ class _Bifurcation2D(base.Base2DNeuronAnalyzer):

self.fixed_points = container
return container
def plot_limit_cycle_by_sim(self, var, duration=100, inputs=(), plot_style=None, tol=0.001, show=False):
print('plot limit cycle ...')

@@ -333,72 +329,55 @@ class _Bifurcation2D(base.Base2DNeuronAnalyzer):
if var not in [self.x_var, self.y_var]:
raise errors.AnalyzerError()

if self.limit_cycle_mon is None:
all_xs, all_ys, all_p0, all_p1 = [], [], [], []

# unstable node
unstable_node = self.fixed_points[utils._2D_UNSTABLE_NODE]
all_xs.extend(unstable_node[self.x_var])
all_ys.extend(unstable_node[self.y_var])
if len(self.dpar_names) == 1:
all_p0.extend(unstable_node['p'])
elif len(self.dpar_names) == 2:
all_p0.extend(unstable_node['p0'])
all_p1.extend(unstable_node['p1'])
else:
raise ValueError

# unstable focus
unstable_focus = self.fixed_points[utils._2D_UNSTABLE_FOCUS]
all_xs.extend(unstable_focus[self.x_var])
all_ys.extend(unstable_focus[self.y_var])
if len(self.dpar_names) == 1:
all_p0.extend(unstable_focus['p'])
elif len(self.dpar_names) == 2:
all_p0.extend(unstable_focus['p0'])
all_p1.extend(unstable_focus['p1'])
else:
raise ValueError

# format points
all_xs = np.array(all_xs)
all_ys = np.array(all_ys)
all_p0 = np.array(all_p0)
all_p1 = np.array(all_p1)
all_xs, all_ys, all_p0, all_p1 = [], [], [], []

# unstable node
unstable_node = self.fixed_points[stability.UNSTABLE_NODE_2D]
all_xs.extend(unstable_node[self.x_var])
all_ys.extend(unstable_node[self.y_var])
if len(self.dpar_names) == 1:
all_p0.extend(unstable_node['p'])
elif len(self.dpar_names) == 2:
all_p0.extend(unstable_node['p0'])
all_p1.extend(unstable_node['p1'])
else:
raise ValueError

# fixed variables
fixed_vars = dict()
for key, val in self.fixed_vars.items():
fixed_vars[key] = val
fixed_vars[self.dpar_names[0]] = all_p0
if len(self.dpar_names) == 2:
fixed_vars[self.dpar_names[1]] = all_p1

# initialize neuron group
length = all_xs.shape[0]
group = core.NeuGroup(self.model,
geometry=length,
monitors=self.dvar_names,
pars_update=self.pars_update)

# group initial state
group.ST[self.x_var] = all_xs
group.ST[self.y_var] = all_ys
for key, val in fixed_vars.items():
if key in group.ST:
group.ST[key] = val

# run neuron group
group.runner = core.TrajectoryRunner(group,
target_vars=self.dvar_names,
fixed_vars=fixed_vars)
group.run(duration=duration, inputs=inputs)

self.limit_cycle_mon = group.mon
self.limit_cycle_p0 = all_p0
self.limit_cycle_p1 = all_p1
# unstable focus
unstable_focus = self.fixed_points[stability.UNSTABLE_FOCUS_2D]
all_xs.extend(unstable_focus[self.x_var])
all_ys.extend(unstable_focus[self.y_var])
if len(self.dpar_names) == 1:
all_p0.extend(unstable_focus['p'])
elif len(self.dpar_names) == 2:
all_p0.extend(unstable_focus['p0'])
all_p1.extend(unstable_focus['p1'])
else:
length = self.limit_cycle_mon[var].shape[1]
raise ValueError

# format points
all_xs = np.array(all_xs)
all_ys = np.array(all_ys)
all_p0 = np.array(all_p0)
all_p1 = np.array(all_p1)

# fixed variables
fixed_vars = dict()
for key, val in self.fixed_vars.items():
fixed_vars[key] = val
fixed_vars[self.dpar_names[0]] = all_p0
if len(self.dpar_names) == 2:
fixed_vars[self.dpar_names[1]] = all_p1

# initialize neuron group
length = all_xs.shape[0]
traj_group = Trajectory(size=length,
integrals=self.model.integrals,
target_vars={self.x_var: all_xs, self.y_var: all_ys},
fixed_vars=fixed_vars,
pars_update=self.pars_update,
scope=self.model.scopes)
traj_group.run(duration=duration)

# find limit cycles
limit_cycle_max = []
@@ -407,16 +386,16 @@ class _Bifurcation2D(base.Base2DNeuronAnalyzer):
p0_limit_cycle = []
p1_limit_cycle = []
for i in range(length):
data = self.limit_cycle_mon[var][:, i]
data = traj_group.mon[var][:, i]
max_index = utils.find_indexes_of_limit_cycle_max(data, tol=tol)
if max_index[0] != -1:
x_cycle = data[max_index[0]: max_index[1]]
limit_cycle_max.append(data[max_index[1]])
limit_cycle_min.append(x_cycle.min())
# limit_cycle.append(x_cycle)
p0_limit_cycle.append(self.limit_cycle_p0[i])
p0_limit_cycle.append(all_p0[i])
if len(self.dpar_names) == 2:
p1_limit_cycle.append(self.limit_cycle_p1[i])
p1_limit_cycle.append(all_p1[i])
self.fixed_points['limit_cycle'] = {var: {'max': limit_cycle_max,
'min': limit_cycle_min,
# 'cycle': limit_cycle
@@ -443,7 +422,9 @@ class _Bifurcation2D(base.Base2DNeuronAnalyzer):
if show:
plt.show()

del traj_group
gc.collect()


class FastSlowBifurcation(object):
"""Fast slow analysis analysis proposed by John Rinzel [1]_ [2]_ [3]_.
@@ -469,12 +450,10 @@ class FastSlowBifurcation(object):

"""

def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
def __init__(self, integrals, fast_vars, slow_vars, fixed_vars=None,
pars_update=None, numerical_resolution=0.1, options=None):
# check "model"
if not isinstance(model, core.NeuType):
raise errors.ModelUseError('FastSlowBifurcation only support neuron type model.')
self.model = model
self.model = utils.transform_integrals_to_model(integrals)

# check "fast_vars"
if not isinstance(fast_vars, dict):
@@ -495,6 +474,9 @@ class FastSlowBifurcation(object):
if len(slow_vars) > 2:
raise errors.ModelUseError("FastSlowBifurcation can only analyze the system with less "
"than two-variable slow subsystem.")
for key in self.slow_vars:
self.model.variables.remove(key)
self.model.parameters.append(key)

# check "fixed_vars"
if fixed_vars is None:
@@ -511,13 +493,13 @@ class FastSlowBifurcation(object):
raise errors.ModelUseError('"pars_update" must be a dict the format of: '
'{"Par A": A_value, "Par B": B_value}')
for key in pars_update.keys():
if key not in model.step_scopes:
raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{model.name}" model. ')
if (key not in self.model.scopes) and (key not in self.model.parameters):
raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{integrals}" model. ')
self.pars_update = pars_update

# bifurcation analysis
if len(self.fast_vars) == 1:
self.analyzer = _FastSlow1D(model=model,
self.analyzer = _FastSlow1D(model_or_integrals=self.model,
fast_vars=fast_vars,
slow_vars=slow_vars,
fixed_vars=fixed_vars,
@@ -526,7 +508,7 @@ class FastSlowBifurcation(object):
options=options)

elif len(self.fast_vars) == 2:
self.analyzer = _FastSlow2D(model=model,
self.analyzer = _FastSlow2D(model_or_integrals=self.model,
fast_vars=fast_vars,
slow_vars=slow_vars,
fixed_vars=fixed_vars,
@@ -548,9 +530,14 @@ class FastSlowBifurcation(object):


class _FastSlowTrajectory(object):
def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
def __init__(self, model_or_intgs, fast_vars, slow_vars, fixed_vars=None,
pars_update=None, **kwargs):
self.model = model
if isinstance(model_or_intgs, utils.DynamicModel):
self.model = model_or_intgs
elif (isinstance(model_or_intgs, (list, tuple)) and callable(model_or_intgs[0])) or callable(model_or_intgs):
self.model = utils.transform_integrals_to_model(model_or_intgs)
else:
raise ValueError
self.fast_vars = fast_vars
self.slow_vars = slow_vars
self.fixed_vars = fixed_vars
@@ -572,20 +559,7 @@ class _FastSlowTrajectory(object):
else:
self.slow_var_names = list(sorted(slow_vars.keys()))

# cannot update dynamical parameters
all_vars = self.fast_var_names + self.slow_var_names
self.traj_group = core.NeuGroup(model,
geometry=1,
monitors=all_vars,
pars_update=pars_update)
self.traj_group.runner = core.TrajectoryRunner(self.traj_group,
target_vars=all_vars,
fixed_vars=fixed_vars)
self.traj_initial = {key: val[0] for key, val in self.traj_group.ST.items()
if not key.startswith('_')}
self.traj_net = core.Network(self.traj_group)

def plot_trajectory(self, initials, duration, plot_duration=None, inputs=(), show=False):
def plot_trajectory(self, initials, duration, plot_duration=None, show=False):
"""Plot trajectories according to the settings.

Parameters
@@ -606,8 +580,6 @@ class _FastSlowTrajectory(object):
The duration to plot. It can be a tuple with ``(start, end)``. It can
also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify
the plot duration for each initial value running.
inputs : tuple, list
The inputs to the model. Same with the ``inputs`` in ``NeuGroup.run()``
show : bool
Whether show or not.
"""
@@ -650,29 +622,15 @@ class _FastSlowTrajectory(object):
else:
assert len(plot_duration) == len(initials)

# 4. format the inputs
if len(inputs):
if isinstance(inputs[0], (tuple, list)):
inputs = [(self.traj_group,) + tuple(input) for input in inputs]
elif isinstance(inputs[0], str):
inputs = [(self.traj_group,) + tuple(inputs)]
else:
raise errors.ModelUseError()

# 5. run the network
for init_i, initial in enumerate(initials):
# 5.1 set the initial value
for key, val in self.traj_initial.items():
self.traj_group.ST[key] = val
for key in all_vars:
self.traj_group.ST[key] = initial[key]
for key, val in self.fixed_vars.items():
if key in self.traj_group.ST:
self.traj_group.ST[key] = val

# 5.2 run the model
self.traj_net.run(duration=duration[init_i], inputs=inputs,
report=False, data_to_host=True, verbose=False)
traj_group = Trajectory(size=1,
integrals=self.model.integrals,
target_vars=initial,
fixed_vars=self.fixed_vars,
pars_update=self.pars_update,
scope=self.model.scopes)
traj_group.run(duration=duration[init_i], report=False)

# 5.3 legend
legend = f'$traj_{init_i}$: '
@@ -681,13 +639,13 @@ class _FastSlowTrajectory(object):
legend = legend[:-2]

# 5.4 trajectory
start = int(plot_duration[init_i][0] / profile.get_dt())
end = int(plot_duration[init_i][1] / profile.get_dt())
start = int(plot_duration[init_i][0] / backend.get_dt())
end = int(plot_duration[init_i][1] / backend.get_dt())

# 5.5 visualization
for var_name in self.fast_var_names:
s0 = self.traj_group.mon[self.slow_var_names[0]][start: end, 0]
fast = self.traj_group.mon[var_name][start: end, 0]
s0 = traj_group.mon[self.slow_var_names[0]][start: end, 0]
fast = traj_group.mon[var_name][start: end, 0]

fig = plt.figure(var_name)
if len(self.slow_var_names) == 1:
@@ -700,7 +658,7 @@ class _FastSlowTrajectory(object):

elif len(self.slow_var_names) == 2:
fig.gca(projection='3d')
s1 = self.traj_group.mon[self.slow_var_names[1]][start: end, 0]
s1 = traj_group.mon[self.slow_var_names[1]][start: end, 0]
plt.plot(s0, s1, fast, label=legend)
else:
raise errors.AnalyzerError
@@ -730,18 +688,17 @@ class _FastSlowTrajectory(object):
plt.show()



class _FastSlow1D(_Bifurcation1D):
def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
def __init__(self, model_or_integrals, fast_vars, slow_vars, fixed_vars=None,
pars_update=None, numerical_resolution=0.1, options=None):
super(_FastSlow1D, self).__init__(model=model,
super(_FastSlow1D, self).__init__(model_or_integrals=model_or_integrals,
target_pars=slow_vars,
target_vars=fast_vars,
fixed_vars=fixed_vars,
pars_update=pars_update,
numerical_resolution=numerical_resolution,
options=options)
self.traj = _FastSlowTrajectory(model=model,
self.traj = _FastSlowTrajectory(model_or_intgs=model_or_integrals,
fast_vars=fast_vars,
slow_vars=slow_vars,
fixed_vars=fixed_vars,
@@ -760,16 +717,16 @@ class _FastSlow1D(_Bifurcation1D):


class _FastSlow2D(_Bifurcation2D):
def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
def __init__(self, model_or_integrals, fast_vars, slow_vars, fixed_vars=None,
pars_update=None, numerical_resolution=0.1, options=None):
super(_FastSlow2D, self).__init__(model=model,
super(_FastSlow2D, self).__init__(model_or_integrals=model_or_integrals,
target_pars=slow_vars,
target_vars=fast_vars,
fixed_vars=fixed_vars,
pars_update=pars_update,
numerical_resolution=numerical_resolution,
options=options)
self.traj = _FastSlowTrajectory(model=model,
self.traj = _FastSlowTrajectory(model_or_intgs=model_or_integrals,
fast_vars=fast_vars,
slow_vars=slow_vars,
fixed_vars=fixed_vars,

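Taken together, the bifurcation entry points now consume integral functions instead of a `NeuType`. A minimal codimension-1 sketch, assuming the decorator form of `bp.odeint` (the model itself is illustrative); `Bifurcation`, its keyword arguments, and `plot_bifurcation` all appear in the diff above:

```python
import brainpy as bp

@bp.odeint
def int_x(x, t, a):
    # 1D pitchfork-style system: dx/dt = a*x - x^3 (illustrative)
    return a * x - x ** 3

an = bp.analysis.Bifurcation(int_x,
                             target_pars={'a': [-1., 1.]},  # parameter range
                             target_vars={'x': [-2., 2.]},  # variable range
                             numerical_resolution=0.01)
an.plot_bifurcation(show=True)
```

`FastSlowBifurcation` follows the same pattern, with `fast_vars` and `slow_vars` in place of `target_vars` and `target_pars`.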

brainpy/analysis/phase_plane.py (+121 -119)

@@ -3,11 +3,12 @@
import matplotlib.pyplot as plt
import numpy as np

from . import base
from . import utils
from .. import core
from .. import errors
from .. import profile
from brainpy import backend
from brainpy import errors
from brainpy.analysis import base
from brainpy.analysis import stability
from brainpy.analysis import utils
from brainpy.analysis.trajectory import Trajectory

__all__ = [
'PhasePlane',
@@ -26,7 +27,7 @@ class PhasePlane(object):

Parameters
----------
model : NeuType
integrals : callable, or list/tuple of callables
The integral functions which define the differential equations,
created with `brainpy.odeint` or related wrappers.
target_vars : dict
@@ -76,23 +77,20 @@ class PhasePlane(object):

def __init__(
self,
model,
integrals,
target_vars,
fixed_vars=None,
pars_update=None,
numerical_resolution=0.1,
options=None,
):

# check "model"
if not isinstance(model, core.NeuType):
raise errors.ModelUseError('Phase plane analysis only support neuron type model.')
self.model = model
self.model = utils.transform_integrals_to_model(integrals)

# check "target_vars"
if not isinstance(target_vars, dict):
raise errors.ModelUseError('"target_vars" must a dict with the format of: '
'{"Variable A": [A_min, A_max], "Variable B": [B_min, B_max]}')
'{"Variable A": [A_min, A_max], "Variable B": [B_min, B_max]}')
self.target_vars = target_vars

# check "fixed_vars"
@@ -100,7 +98,7 @@ class PhasePlane(object):
fixed_vars = dict()
if not isinstance(fixed_vars, dict):
raise errors.ModelUseError('"fixed_vars" must be a dict with the format of: '
'{"Variable A": A_value, "Variable B": B_value}')
'{"Variable A": A_value, "Variable B": B_value}')
self.fixed_vars = fixed_vars

# check "pars_update"
@@ -108,22 +106,22 @@ class PhasePlane(object):
pars_update = dict()
if not isinstance(pars_update, dict):
raise errors.ModelUseError('"pars_update" must be a dict with the format of: '
'{"Par A": A_value, "Par B": B_value}')
'{"Par A": A_value, "Par B": B_value}')
for key in pars_update.keys():
if key not in model.step_scopes:
raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{model.name}" model.')
if (key not in self.model.scopes) and (key not in self.model.parameters):
raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{integrals}" model.')
self.pars_update = pars_update

# analyzer
if len(target_vars) == 1:
self.analyzer = _PhasePlane1D(model=model,
self.analyzer = _PhasePlane1D(model_or_integrals=self.model,
target_vars=target_vars,
fixed_vars=fixed_vars,
pars_update=pars_update,
numerical_resolution=numerical_resolution,
options=options)
elif len(target_vars) == 2:
self.analyzer = _PhasePlane2D(model=model,
self.analyzer = _PhasePlane2D(model_or_integrals=self.model,
target_vars=target_vars,
fixed_vars=fixed_vars,
pars_update=pars_update,
@@ -131,8 +129,8 @@ class PhasePlane(object):
options=options)
else:
raise errors.ModelUseError('BrainPy only supports 1D/2D phase plane analysis. '
'Or, you can set "fixed_vars" to fix other variables, '
'then make 1D/2D phase plane analysis.')
'Or, you can set "fixed_vars" to fix other variables, '
'then make 1D/2D phase plane analysis.')

def plot_vector_field(self, *args, **kwargs):
"""Plot vector filed of a 2D/1D system."""
@@ -146,22 +144,73 @@ class PhasePlane(object):
"""Plot nullcline (only supported in 2D system)."""
self.analyzer.plot_nullcline(*args, **kwargs)

def plot_trajectory(self, *args, **kwargs):
"""Plot trajectories (only supported in 2D system)."""
self.analyzer.plot_trajectory(*args, **kwargs)
def plot_trajectory(self, initials, duration, plot_duration=None, axes='v-v', show=False):
"""Plot trajectories according to the settings.

def plot_limit_cycle_by_sim(self, *args, **kwargs):
"""Find the limit cycles through the simulation, and then plot."""
self.analyzer.plot_limit_cycle_by_sim(*args, **kwargs)
Parameters
----------
initials : list, tuple, dict
The initial value setting of the targets. It can be a tuple/list of floats to specify
each value of dynamical variables (for example, ``(a, b)``). It can also be a
tuple/list of tuple to specify multiple initial values (for example,
``[(a1, b1), (a2, b2)]``).
duration : int, float, tuple, list
The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
It can be a int/float (``t_end``) to specify the same running end time,
or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify
the start and end simulation time. Or, it can be a list of tuple
(``[(t1_start, t1_end), (t2_start, t2_end)]``) to specify the specific
start and end simulation time for each initial value.
plot_duration : tuple, list, optional
The duration to plot. It can be a tuple with ``(start, end)``. It can
also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify
the plot duration for each initial value running.
axes : str
The axes to plot. It can be:

- 'v-v'
Plot the trajectory in the 'x_var'-'y_var' axis.
- 't-v'
Plot the trajectory in the 'time'-'var' axis.
show : bool
Whether show or not.
"""
self.analyzer.plot_trajectory(initials=initials,
duration=duration,
plot_duration=plot_duration,
axes=axes,
show=show)

def plot_limit_cycle_by_sim(self, initials, duration, tol=0.001, show=False):
"""Plot limit cycles according to the settings.

Parameters
----------
initials : list, tuple
The initial value setting of the targets. It can be a tuple/list of floats to specify
each value of dynamical variables (for example, ``(a, b)``). It can also be a
tuple/list of tuple to specify multiple initial values (for example,
``[(a1, b1), (a2, b2)]``).
duration : int, float, tuple, list
The running duration. Same as the ``duration`` in ``NeuGroup.run()``.
It can be an int/float (``t_end``) to specify the same running end time,
or a tuple/list of int/float (``(t_start, t_end)``) to specify
the start and end simulation times. Or, it can be a list of tuples
(``[(t1_start, t1_end), (t2_start, t2_end)]``) to specify the
start and end simulation times for each initial value.
show : bool
Whether to show the figure.
"""
self.analyzer.plot_limit_cycle_by_sim(initials=initials,
duration=duration,
tol=tol,
show=show)


class _PhasePlane1D(base.Base1DNeuronAnalyzer):
"""Phase plane analyzer for 1D system.
"""

def __init__(self, *args, **kwargs):
super(_PhasePlane1D, self).__init__(*args, **kwargs)

def plot_vector_field(self, show=False):
"""Plot the vector filed.

@@ -182,7 +231,7 @@ class _PhasePlane1D(base.Base1DNeuronAnalyzer):
y_val = self.get_f_dx()(self.resolutions[self.x_var])
except TypeError:
raise errors.ModelUseError('Missing variables. Please check and set missing '
'variables to "fixed_vars".')

# 2. visualization
label = f"d{self.x_var}dt"
@@ -192,7 +241,8 @@ class _PhasePlane1D(base.Base1DNeuronAnalyzer):

plt.xlabel(self.x_var)
plt.ylabel(label)
plt.xlim(*utils.rescale(self.target_vars[self.x_var],
scale=(self.options.lim_scale - 1.) / 2))
plt.legend()
if show:
plt.show()
@@ -219,18 +269,18 @@ class _PhasePlane1D(base.Base1DNeuronAnalyzer):

# 2. stability analysis
x_values = f_fixed_point()
-container = {a: [] for a in utils.get_1d_classification()}
container = {a: [] for a in stability.get_1d_stability_types()}
for i in range(len(x_values)):
x = x_values[i]
dfdx = f_dfdx(x)
-fp_type = utils.stability_analysis(dfdx)
fp_type = stability.stability_analysis(dfdx)
print(f"Fixed point #{i + 1} at {self.x_var}={x} is a {fp_type}.")
container[fp_type].append(x)

# 3. visualization
for fp_type, points in container.items():
if len(points):
-plot_style = utils.plot_scheme[fp_type]
plot_style = stability.plot_scheme[fp_type]
plt.plot(points, [0] * len(points), '.',
markersize=20, **plot_style, label=fp_type)
plt.legend()
@@ -251,27 +301,8 @@ class _PhasePlane1D(base.Base1DNeuronAnalyzer):

class _PhasePlane2D(base.Base2DNeuronAnalyzer):
"""Phase plane analyzer for 2D system.

"""

def __init__(self, *args, **kwargs):
super(_PhasePlane2D, self).__init__(*args, **kwargs)

-# runner for trajectory
-# ---------------------
-
-# cannot update dynamical parameters
-self.traj_group = core.NeuGroup(self.model,
-geometry=1,
-monitors=self.dvar_names,
-pars_update=self.pars_update)
-self.traj_group.runner = core.TrajectoryRunner(self.traj_group,
-target_vars=self.dvar_names,
-fixed_vars=self.fixed_vars)
-self.traj_initial = {key: val[0] for key, val in self.traj_group.ST.items()
-if not key.startswith('_')}
-self.traj_net = core.Network(self.traj_group)

def plot_vector_field(self, plot_method='streamplot', plot_style=None, show=False):
"""Plot the vector field.

@@ -309,14 +340,14 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
dx = self.get_f_dx()(X, Y)
except TypeError:
raise errors.ModelUseError('Missing variables. Please check and set missing '
'variables to "fixed_vars".')

# dy
try:
dy = self.get_f_dy()(X, Y)
except TypeError:
raise errors.ModelUseError('Missing variables. Please check and set missing '
'variables to "fixed_vars".')

# vector field
if plot_method == 'quiver':
@@ -374,11 +405,11 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):

# stability analysis
# ------------------
-container = {a: {'x': [], 'y': []} for a in utils.get_2d_classification()}
container = {a: {'x': [], 'y': []} for a in stability.get_2d_stability_types()}
for i in range(len(x_values)):
x = x_values[i]
y = y_values[i]
-fp_type = utils.stability_analysis(f_jacobian(x, y))
fp_type = stability.stability_analysis(f_jacobian(x, y))
print(f"Fixed point #{i + 1} at {self.x_var}={x}, {self.y_var}={y} is a {fp_type}.")
container[fp_type]['x'].append(x)
container[fp_type]['y'].append(y)
@@ -387,7 +418,7 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
# -------------
for fp_type, points in container.items():
if len(points['x']):
-plot_style = utils.plot_scheme[fp_type]
plot_style = stability.plot_scheme[fp_type]
plt.plot(points['x'], points['y'], '.', markersize=20, **plot_style, label=fp_type)
plt.legend()
if show:
@@ -442,7 +473,7 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
y_values_in_y_eq = y_by_x['f'](xs)
except TypeError:
raise errors.ModelUseError('Missing variables. Please check and set missing '
'variables to "fixed_vars".')
x_values_in_y_eq = xs
plt.plot(xs, y_values_in_y_eq, **y_style, label=f"{self.y_var} nullcline")

@@ -453,7 +484,7 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
x_values_in_y_eq = x_by_y['f'](ys)
except TypeError:
raise errors.ModelUseError('Missing variables. Please check and set missing '
'variables to "fixed_vars".')
y_values_in_y_eq = ys
plt.plot(x_values_in_y_eq, ys, **y_style, label=f"{self.y_var} nullcline")
else:
@@ -476,7 +507,7 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
y_values_in_x_eq = y_by_x['f'](xs)
except TypeError:
raise errors.ModelUseError('Missing variables. Please check and set missing '
'variables to "fixed_vars".')
x_values_in_x_eq = xs
plt.plot(xs, y_values_in_x_eq, **x_style, label=f"{self.x_var} nullcline")

@@ -487,7 +518,7 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
x_values_in_x_eq = x_by_y['f'](ys)
except TypeError:
raise errors.ModelUseError('Missing variables. Please check and set missing '
'variables to "fixed_vars".')
y_values_in_x_eq = ys
plt.plot(x_values_in_x_eq, ys, **x_style, label=f"{self.x_var} nullcline")
else:
@@ -515,12 +546,12 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
return {self.x_eq_group.func_name: (x_values_in_x_eq, y_values_in_x_eq),
self.y_eq_group.func_name: (x_values_in_y_eq, y_values_in_y_eq)}

-def plot_trajectory(self, initials, duration, plot_duration=None, inputs=(), axes='v-v', show=False):
def plot_trajectory(self, initials, duration, plot_duration=None, axes='v-v', show=False):
"""Plot trajectories according to the settings.

Parameters
----------
-initials : list, tuple
initials : list, tuple, dict
The initial value setting of the targets. It can be a tuple/list of floats to specify
each value of dynamical variables (for example, ``(a, b)``). It can also be a
tuple/list of tuples to specify multiple initial values (for example,
@@ -536,8 +567,6 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
The duration to plot. It can be a tuple with ``(start, end)``. It can
also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify
the plot duration for each initial value running.
-inputs : tuple, list
-The inputs to the model. Same with the ``inputs`` in ``NeuGroup.run()``
axes : str
The axes to plot. It can be:

@@ -585,29 +614,17 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
else:
assert len(plot_duration) == len(initials)

-# 4. format the inputs
-if len(inputs):
-if isinstance(inputs[0], (tuple, list)):
-inputs = [(self.traj_group, ) + tuple(input) for input in inputs]
-elif isinstance(inputs[0], str):
-inputs = [(self.traj_group, ) + tuple(inputs)]
-else:
-raise errors.ModelUseError()

# 5. run the network
for init_i, initial in enumerate(initials):
# 5.1 set the initial value
-for key, val in self.traj_initial.items():
-self.traj_group.ST[key] = val
-for key in self.dvar_names:
-self.traj_group.ST[key] = initial[key]
-for key, val in self.fixed_vars.items():
-if key in self.traj_group.ST:
-self.traj_group.ST[key] = val
traj_group = Trajectory(size=1,
integrals=self.model.integrals,
target_vars=initial,
fixed_vars=self.fixed_vars,
pars_update=self.pars_update,
scope=self.model.scopes)

# 5.2 run the model
-self.traj_net.run(duration=duration[init_i], inputs=inputs,
-report=False, data_to_host=True, verbose=False)
traj_group.run(duration=duration[init_i], report=False)

# 5.3 legend
legend = f'$traj_{init_i}$: '
@@ -616,21 +633,21 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
legend = legend[:-2]

# 5.4 trajectory
-start = int(plot_duration[init_i][0] / profile.get_dt())
-end = int(plot_duration[init_i][1] / profile.get_dt())
start = int(plot_duration[init_i][0] / backend.get_dt())
end = int(plot_duration[init_i][1] / backend.get_dt())

# 5.5 visualization
if axes == 'v-v':
-lines = plt.plot(self.traj_group.mon[self.x_var][start: end, 0],
-self.traj_group.mon[self.y_var][start: end, 0],
-label=legend)
lines = plt.plot(traj_group.mon[self.x_var][start: end, 0],
traj_group.mon[self.y_var][start: end, 0],
label=legend)
utils.add_arrow(lines[0])
else:
-plt.plot(self.traj_group.mon.ts[start: end],
-self.traj_group.mon[self.x_var][start: end, 0],
plt.plot(traj_group.mon.ts[start: end],
traj_group.mon[self.x_var][start: end, 0],
label=legend + f', {self.x_var}')
-plt.plot(self.traj_group.mon.ts[start: end],
-self.traj_group.mon[self.y_var][start: end, 0],
plt.plot(traj_group.mon.ts[start: end],
traj_group.mon[self.y_var][start: end, 0],
label=legend + f', {self.y_var}')

# 6. visualization
@@ -647,7 +664,7 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
if show:
plt.show()

-def plot_limit_cycle_by_sim(self, initials, duration, inputs=(), tol=0.001, show=False):
def plot_limit_cycle_by_sim(self, initials, duration, tol=0.001, show=False):
"""Plot trajectories according to the settings.

Parameters
@@ -664,8 +681,6 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
the start and end simulation time. Or, it can be a list of tuple
(``[(t1_start, t1_end), (t2_start, t2_end)]``) to specify the specific
start and end simulation time for each initial value.
-inputs : tuple, list
-The inputs to the model. Same with the ``inputs`` in ``NeuGroup.run()``
show : bool
Whether to show the figure.
"""
@@ -694,31 +709,19 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):
else:
assert len(duration) == len(initials)

-# 4. format the inputs
-if len(inputs):
-if isinstance(inputs[0], (tuple, list)):
-inputs = [(self.traj_group, ) + tuple(input) for input in inputs]
-elif isinstance(inputs[0], str):
-inputs = [(self.traj_group, ) + tuple(inputs)]
-else:
-raise errors.ModelUseError()

# 5. run the network
for init_i, initial in enumerate(initials):
# 5.1 set the initial value
-for key, val in self.traj_initial.items():
-self.traj_group.ST[key] = val
-for key in self.dvar_names:
-self.traj_group.ST[key] = initial[key]
-for key, val in self.fixed_vars.items():
-if key in self.traj_group.ST:
-self.traj_group.ST[key] = val
traj_group = Trajectory(size=1,
integrals=self.model.integrals,
target_vars=initial,
fixed_vars=self.fixed_vars,
pars_update=self.pars_update,
scope=self.model.scopes)

# 5.2 run the model
-self.traj_net.run(duration=duration[init_i], inputs=inputs,
-report=False, data_to_host=True, verbose=False)
-x_data = self.traj_group.mon[self.x_var][:, 0]
-y_data = self.traj_group.mon[self.y_var][:, 0]
traj_group.run(duration=duration[init_i], report=False)
x_data = traj_group.mon[self.x_var][:, 0]
y_data = traj_group.mon[self.y_var][:, 0]
max_index = utils.find_indexes_of_limit_cycle_max(x_data, tol=tol)
if max_index[0] != -1:
x_cycle = x_data[max_index[0]: max_index[1]]
@@ -739,4 +742,3 @@ class _PhasePlane2D(base.Base2DNeuronAnalyzer):

if show:
plt.show()


+18 -9  brainpy/analysis/solver.py

@@ -1,10 +1,14 @@
# -*- coding: utf-8 -*-

from collections import namedtuple
from importlib import import_module

-import numba as nb
import numpy as np
-from scipy.optimize import shgo

try:
numba = import_module('numba')
except ModuleNotFoundError:
numba = None

__all__ = [
'brentq',
@@ -18,9 +22,7 @@ _ECONVERR = -1
results = namedtuple('results', ['root', 'function_calls', 'iterations', 'converged'])


-@nb.njit
-def brentq(f, a, b, args=(), xtol=2e-12, maxiter=100,
-rtol=4 * np.finfo(float).eps):
def brentq(f, a, b, args=(), xtol=2e-14, maxiter=200, rtol=4 * np.finfo(float).eps):
"""
Find a root of a function in a bracketing interval using Brent's method
adapted from Scipy's brentq.
@@ -147,12 +149,14 @@ def brentq(f, a, b, args=(), xtol=2e-12, maxiter=100,
if status == _ECONVERR:
raise RuntimeError("Failed to converge")

-x, funcalls, iterations = root, funcalls, itr
-return x
# x, funcalls, iterations = root, funcalls, itr
return root, funcalls, itr


if numba is not None:
brentq = numba.njit(brentq)


-@nb.njit
def find_root_of_1d(f, f_points, args=(), tol=1e-8):
"""Find the roots of the given function by numerical methods.

@@ -190,7 +194,7 @@ def find_root_of_1d(f, f_points, args=(), tol=1e-8):
f_i += 2
else:
if not np.isnan(fr_sign) and fl_sign != fr_sign:
-root = brentq(f, f_points[f_i - 1], f_points[f_i], args)
root, funcalls, itr = brentq(f, f_points[f_i - 1], f_points[f_i], args)
if abs(f(root, *args)) < tol:
roots.append(root)
fl_sign = fr_sign
@@ -199,6 +203,10 @@ def find_root_of_1d(f, f_points, args=(), tol=1e-8):
return roots


if numba is not None:
find_root_of_1d = numba.njit(find_root_of_1d)
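
A usage sketch of the scanning root finder above (assuming this PR's module layout; when Numba is installed the solver is jit-compiled, so the callback must be jit-compiled as well — without Numba everything runs as plain Python):

import numba
import numpy as np
from brainpy.analysis.solver import find_root_of_1d

@numba.njit
def f(x):
    # sign changes bracket the roots of x**3 - x at -1, 0 and 1
    return x ** 3 - x

roots = find_root_of_1d(f, np.linspace(-2., 2., 401))
print(roots)  # expected: approximately [-1.0, 0.0, 1.0]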


def find_root_of_2d(f, x_bound, y_bound, args=(), shgo_args=None,
fl_tol=1e-6, xl_tol=1e-4, verbose=False):
"""Find the root of a two dimensional function.
@@ -249,6 +257,7 @@ def find_root_of_2d(f, x_bound, y_bound, args=(), shgo_args=None,
res : tuple
The roots.
"""
from scipy.optimize import shgo

# 1. shgo arguments
if shgo_args is None:


+161 -0  brainpy/analysis/stability.py

@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-

import numpy as np

__all__ = [
'CENTER_MANIFOLD',
'SADDLE_NODE',
'STABLE_POINT_1D',
'UNSTABLE_POINT_1D',
'CENTER_2D',
'STABLE_NODE_2D',
'STABLE_FOCUS_2D',
'STABLE_STAR_2D',
'STABLE_DEGENERATE_2D',
'UNSTABLE_NODE_2D',
'UNSTABLE_FOCUS_2D',
'UNSTABLE_STAR_2D',
'UNSTABLE_DEGENERATE_2D',
'UNSTABLE_LINE_2D',

'get_1d_stability_types',
'get_2d_stability_types',

'stability_analysis',
]

CENTER_MANIFOLD = 'center manifold'
SADDLE_NODE = 'saddle node'
STABLE_POINT_1D = 'stable point'
UNSTABLE_POINT_1D = 'unstable point'
CENTER_2D = 'center'
STABLE_NODE_2D = 'stable node'
STABLE_FOCUS_2D = 'stable focus'
STABLE_STAR_2D = 'stable star'
STABLE_DEGENERATE_2D = 'stable degenerate'
UNSTABLE_NODE_2D = 'unstable node'
UNSTABLE_FOCUS_2D = 'unstable focus'
UNSTABLE_STAR_2D = 'unstable star'
UNSTABLE_DEGENERATE_2D = 'unstable degenerate'
UNSTABLE_LINE_2D = 'unstable line'

plot_scheme = {
STABLE_POINT_1D: {"color": 'tab:red'},
STABLE_NODE_2D: {"color": 'tab:red'},

UNSTABLE_POINT_1D: {"color": 'tab:olive'},
UNSTABLE_NODE_2D: {"color": 'tab:olive'},

STABLE_FOCUS_2D: {"color": 'tab:purple'},
UNSTABLE_FOCUS_2D: {"color": 'tab:cyan'},

SADDLE_NODE: {"color": 'tab:blue'},
CENTER_2D: {'color': 'lime'},
# _2D_UNIFORM_MOTION: {'color': 'red'},

CENTER_MANIFOLD: {'color': 'orangered'},
UNSTABLE_LINE_2D: {'color': 'dodgerblue'},

UNSTABLE_STAR_2D: {'color': 'green'},
STABLE_STAR_2D: {'color': 'orange'},

UNSTABLE_DEGENERATE_2D: {'color': 'springgreen'},
STABLE_DEGENERATE_2D: {'color': 'blueviolet'},
}


def get_1d_stability_types():
return [SADDLE_NODE, STABLE_POINT_1D, UNSTABLE_POINT_1D]


def get_2d_stability_types():
return [SADDLE_NODE, CENTER_2D, STABLE_NODE_2D, STABLE_FOCUS_2D,
STABLE_STAR_2D, CENTER_MANIFOLD, UNSTABLE_NODE_2D,
UNSTABLE_FOCUS_2D, UNSTABLE_STAR_2D, UNSTABLE_LINE_2D,
STABLE_DEGENERATE_2D, UNSTABLE_DEGENERATE_2D]


def get_3d_stability_types():
return []


def stability_analysis(derivatives):
"""Stability analysis for fixed point [1]_.

Parameters
----------
derivatives : float, tuple, list, np.ndarray
The derivative of the f.

Returns
-------
fp_type : str
The type of the fixed point.

References
----------

.. [1] http://www.egwald.ca/nonlineardynamics/twodimensionaldynamics.php

"""
if np.size(derivatives) == 1:
if derivatives == 0:
return SADDLE_NODE
elif derivatives > 0:
return UNSTABLE_POINT_1D
else:
return STABLE_POINT_1D

elif np.size(derivatives) == 4:
a = derivatives[0][0]
b = derivatives[0][1]
c = derivatives[1][0]
d = derivatives[1][1]

# trace
p = a + d
# det
q = a * d - b * c

# judgement
if q < 0:
return SADDLE_NODE
elif q == 0:
if p <= 0:
return CENTER_MANIFOLD
else:
return UNSTABLE_LINE_2D
else:
# parabola
e = p * p - 4 * q
if p == 0:
return CENTER_2D
elif p > 0:
if e < 0:
return UNSTABLE_FOCUS_2D
elif e > 0:
return UNSTABLE_NODE_2D
else:
w = np.linalg.eigvals(derivatives)
if w[0] == w[1]:
return UNSTABLE_DEGENERATE_2D
else:
return UNSTABLE_STAR_2D
else:
if e < 0:
return STABLE_FOCUS_2D
elif e > 0:
return STABLE_NODE_2D
else:
w = np.linalg.eigvals(derivatives)
if w[0] == w[1]:
return STABLE_DEGENERATE_2D
else:
return STABLE_STAR_2D

elif np.size(derivatives) == 9:
pass

else:
raise ValueError('Unknown derivatives, only supports the Jacobian '
'matrix with the shape of (1,), (2, 2), or (3, 3).')
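
For example, the trace/determinant rules above classify these two Jacobians as follows (a small check, assuming the module path introduced in this PR):

import numpy as np
from brainpy.analysis.stability import stability_analysis

# det = -1 < 0, so the fixed point is a saddle.
print(stability_analysis(np.array([[1., 0.], [0., -1.]])))    # 'saddle node'
# trace = -2 < 0 and trace**2 - 4*det = 4 - 20 < 0: a stable focus.
print(stability_analysis(np.array([[-1., -2.], [2., -1.]])))  # 'stable focus'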

+109 -0  brainpy/analysis/trajectory.py

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-

from brainpy import backend
from brainpy.simulation.utils import run_model
from brainpy.tools import DictPlus

__all__ = [
'Trajectory',
]


class Trajectory(object):
def __init__(self, size, integrals, target_vars, fixed_vars,
pars_update, scope, show_code=False):
"""Trajectory Class.

Parameters
----------
size : int, tuple, list
The network size.
integrals : list of functions, function
The integral functions.
target_vars : dict
The target variables, with the format of "{key: initial_v}".
fixed_vars : dict
The fixed variables, with the format of "{key: fixed_v}".
pars_update : dict
The parameters to update.
scope : dict
The code scope used to compile the running function.
"""
if callable(integrals):
integrals = (integrals,)
elif isinstance(integrals, (list, tuple)) and callable(integrals[0]):
integrals = tuple(integrals)
else:
raise ValueError
self.integrals = integrals
self.target_vars = target_vars
self.fixed_vars = fixed_vars
self.pars_update = pars_update
self.scope = {key: val for key, val in scope.items()}
self.show_code = show_code

# check network size
if isinstance(size, int):
size = (size,)
elif isinstance(size, (tuple, list)):
assert isinstance(size[0], int)
size = tuple(size)
else:
raise ValueError

# monitors, variables, parameters
self.mon = DictPlus()
self.vars_and_pars = DictPlus()
for key, val in target_vars.items():
self.vars_and_pars[key] = backend.ones(size) * val
self.mon[key] = backend.zeros((1,) + size)
for key, val in fixed_vars.items():
self.vars_and_pars[key] = backend.ones(size) * val
for key, val in pars_update.items():
self.vars_and_pars[key] = val
self.scope['VP'] = self.vars_and_pars
self.scope['MON'] = self.mon
self.scope['_fixed_vars'] = fixed_vars

code_lines = ['def run_func(_t, _i, _dt):']
for integral in integrals:
func_name = integral.__name__
self.scope[func_name] = integral
# update the step function
assigns = [f'VP["{var}"]' for var in integral.variables]
calls = [f'VP["{var}"]' for var in integral.variables]
calls.append('_t')
calls.extend([f'VP["{var}"]' for var in integral.parameters[1:]])
code_lines.append(f' {", ".join(assigns)} = {func_name}({", ".join(calls)})')
# reassign the fixed variables
for key, val in fixed_vars.items():
code_lines.append(f' VP["{key}"][:] = _fixed_vars["{key}"]')
# monitor the target variables
for key in target_vars.keys():
code_lines.append(f' MON["{key}"][_i] = VP["{key}"]')
# compile
code = '\n'.join(code_lines)
if show_code:
print(code)
print(self.scope)
print()

# recompile
exec(compile(code, '', 'exec'), self.scope)
self.run_func = self.scope['run_func']

def run(self, duration, report=False, report_percent=0.1):
if isinstance(duration, (int, float)):
duration = [0, duration]
elif isinstance(duration, (tuple, list)):
assert len(duration) == 2
duration = tuple(duration)
else:
raise ValueError

# get the times
times = backend.arange(duration[0], duration[1], backend.get_dt())
# reshape the monitor
for key in self.mon.keys():
self.mon[key] = backend.zeros((len(times),) + backend.shape(self.mon[key])[1:])
# run the model
run_model(run_func=self.run_func, times=times, report=report, report_percent=report_percent)
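
A hypothetical sketch of how the phase-plane analyzer drives this class; ``model`` stands for a DynamicModel produced by transform_integrals_to_model, and the variable names are placeholders:

traj = Trajectory(size=1,
                  integrals=model.integrals,          # BrainPy-wrapped integrators
                  target_vars={'V': -1.0, 'w': 0.0},  # initial values to evolve
                  fixed_vars={},
                  pars_update={},
                  scope=model.scopes)
traj.run(duration=100.)
v_series = traj.mon['V']  # shape (num_steps, 1): one column per element of `size`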

+102 -127  brainpy/analysis/utils.py

@@ -1,18 +1,27 @@
# -*- coding: utf-8 -*-


import _thread as thread
import inspect
import threading

import numpy as np
-from numba import njit
-from numba.core.dispatcher import Dispatcher

-from .. import backend
-from .. import tools
from brainpy import errors
from brainpy import tools
from brainpy.integrators import ast_analysis
from brainpy.integrators import sympy_analysis

try:
import numba
from numba.core.dispatcher import Dispatcher
except ModuleNotFoundError:
numba = None
Dispatcher = None

__all__ = [
-'stability_analysis',
'transform_integrals_to_model',
'DynamicModel',
'rescale',
'timeout',
'jit_compile',
@@ -20,121 +29,76 @@ __all__ = [
'contain_unknown_symbol',
]

-_SADDLE_NODE = 'saddle-node'
-_1D_STABLE_POINT = 'stable-point'
-_1D_UNSTABLE_POINT = 'unstable-point'
-_2D_CENTER = 'center'
-_2D_STABLE_NODE = 'stable-node'
-_2D_STABLE_FOCUS = 'stable-focus'
-_2D_STABLE_STAR = 'stable-star'
-_2D_STABLE_LINE = 'stable-line'
-_2D_UNSTABLE_NODE = 'unstable-node'
-_2D_UNSTABLE_FOCUS = 'unstable-focus'
-_2D_UNSTABLE_STAR = 'star'
-_2D_UNSTABLE_LINE = 'unstable-line'
-_2D_UNIFORM_MOTION = 'uniform-motion'
-
-plot_scheme = {
-_1D_STABLE_POINT: {"color": 'tab:red'},
-_2D_STABLE_NODE: {"color": 'tab:red'},
-
-_1D_UNSTABLE_POINT: {"color": 'tab:olive'},
-_2D_UNSTABLE_NODE: {"color": 'tab:olive'},
-
-_2D_STABLE_FOCUS: {"color": 'tab:purple'},
-_2D_UNSTABLE_FOCUS: {"color": 'tab:cyan'},
-
-_SADDLE_NODE: {"color": 'tab:blue'},
-
-_2D_STABLE_LINE: {'color': 'orangered'},
-_2D_UNSTABLE_LINE: {'color': 'dodgerblue'},
-_2D_CENTER: {'color': 'lime'},
-_2D_UNSTABLE_STAR: {'color': 'green'},
-_2D_STABLE_STAR: {'color': 'orange'},
-_2D_UNIFORM_MOTION: {'color': 'red'},
-}
-
-
-def get_1d_classification():
-return [_SADDLE_NODE, _1D_STABLE_POINT, _1D_UNSTABLE_POINT]
-
-
-def get_2d_classification():
-return [_SADDLE_NODE, _2D_CENTER, _2D_STABLE_NODE, _2D_STABLE_FOCUS,
-_2D_STABLE_STAR, _2D_STABLE_LINE, _2D_UNSTABLE_NODE,
-_2D_UNSTABLE_FOCUS, _2D_UNSTABLE_STAR, _2D_UNSTABLE_LINE,
-_2D_UNIFORM_MOTION]
-
-
-def stability_analysis(derivative):
-"""Stability analysis for fixed point [1]_.
-
-Parameters
-----------
-derivative : float, tuple, list, np.ndarray
-The derivative of the f.
-
-Returns
--------
-fp_type : str
-The type of the fixed point.
-
-References
-----------
-
-.. [1] http://www.egwald.ca/nonlineardynamics/twodimensionaldynamics.php
-
-"""
-if np.size(derivative) == 1:
-if derivative == 0:
-return _SADDLE_NODE
-elif derivative > 0:
-return _1D_STABLE_POINT
-else:
-return _1D_UNSTABLE_POINT
-
-elif np.size(derivative) == 4:
-a = derivative[0][0]
-b = derivative[0][1]
-c = derivative[1][0]
-d = derivative[1][1]
-
-# trace
-p = a + d
-# det
-q = a * d - b * c
-
-# judgement
-if q < 0:
-return _SADDLE_NODE
-elif q == 0:
-if p < 0:
-return _2D_STABLE_LINE
-elif p > 0:
-return _2D_UNSTABLE_LINE
-else:
-return _2D_UNIFORM_MOTION
-else:
-# parabola
-e = p * p - 4 * q
-if p == 0:
-return _2D_CENTER
-elif p > 0:
-if e < 0:
-return _2D_UNSTABLE_FOCUS
-elif e == 0:
-return _2D_UNSTABLE_STAR
-else:
-return _2D_UNSTABLE_NODE
-else:
-if e < 0:
-return _2D_STABLE_FOCUS
-elif e == 0:
-return _2D_STABLE_STAR
-else:
-return _2D_STABLE_NODE
-else:
-raise ValueError('Unknown derivatives.')


def transform_integrals_to_model(integrals):
if callable(integrals):
integrals = [integrals]

all_scope = dict()
all_variables = set()
all_parameters = set()
analyzers = []
for integral in integrals:
# integral function
if Dispatcher is not None and isinstance(integral, Dispatcher):
integral = integral.py_func
else:
integral = integral

# original function
f = integral.origin_f
if Dispatcher is not None and isinstance(f, Dispatcher):
f = f.py_func
func_name = f.__name__

# code scope
closure_vars = inspect.getclosurevars(f)
code_scope = dict(closure_vars.nonlocals)
code_scope.update(dict(closure_vars.globals))

# separate variables
analysis = ast_analysis.separate_variables(f)
variables_for_returns = analysis['variables_for_returns']
expressions_for_returns = analysis['expressions_for_returns']
for vi, (key, vars) in enumerate(variables_for_returns.items()):
variables = []
for v in vars:
if len(v) > 1:
raise ValueError('Cannot analyze a multi-assignment code line.')
variables.append(v[0])
expressions = expressions_for_returns[key]
var_name = integral.variables[vi]
DE = sympy_analysis.SingleDiffEq(var_name=var_name,
variables=variables,
expressions=expressions,
derivative_expr=key,
scope=code_scope,
func_name=func_name)
analyzers.append(DE)

# others
for var in integral.variables:
if var in all_variables:
raise errors.ModelDefError(f'Variable {var} has been defined before. Cannot group '
f'this integral as a dynamic system.')
all_variables.add(var)
all_parameters.update(integral.parameters)
all_scope.update(code_scope)

return DynamicModel(integrals=integrals,
analyzers=analyzers,
variables=list(all_variables),
parameters=list(all_parameters),
scopes=all_scope)


class DynamicModel(object):
def __init__(self, integrals, analyzers, variables, parameters, scopes):
self.integrals = integrals
self.analyzers = analyzers
self.variables = variables
self.parameters = parameters
self.scopes = scopes


def rescale(min_max, scale=0.01):
@@ -176,7 +140,7 @@ def timeout(s):


def _jit(func):
-if backend.func_in_numpy_or_math(func):
if sympy_analysis.func_in_numpy_or_math(func):
return func
if isinstance(func, Dispatcher):
return func
@@ -189,7 +153,7 @@ def _jit(func):
for k, v in code_scope.items():
# function
if callable(v):
-if (not backend.func_in_numpy_or_math(v)) and (not isinstance(v, Dispatcher)):
if (not sympy_analysis.func_in_numpy_or_math(v)) and (not isinstance(v, Dispatcher)):
code_scope[k] = _jit(v)
modified = True

@@ -197,17 +161,19 @@ def _jit(func):
func_code = tools.deindent(tools.get_func_source(func))
exec(compile(func_code, '', "exec"), code_scope)
func = code_scope[func.__name__]
-return njit(func)
return numba.njit(func)
else:
-njit(func)
return numba.njit(func)


def jit_compile(scope, func_code, func_name):
if numba is None:
return
# get function scope
func_scope = dict()
for key, val in scope.items():
if callable(val):
-if backend.func_in_numpy_or_math(val):
if sympy_analysis.func_in_numpy_or_math(val):
pass
elif isinstance(val, Dispatcher):
pass
@@ -217,7 +183,7 @@ def jit_compile(scope, func_code, func_name):

# compile function
exec(compile(func_code, '', 'exec'), func_scope)
-return njit(func_scope[func_name])
return numba.njit(func_scope[func_name])


def contain_unknown_symbol(expr, scope):
@@ -271,7 +237,6 @@ def add_arrow(line, position=None, direction='right', size=15, color=None):
size=size)


-@njit
def f1(arr, grad, tol):
condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] >= 0)
indexes = np.where(condition)[0]
@@ -285,7 +250,10 @@ def f1(arr, grad, tol):
return np.array([-1, -1])


if numba is not None:
f1 = numba.njit(f1)

-@njit


def f2(arr, grad, tol):
condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] <= 0)
indexes = np.where(condition)[0]
@@ -299,6 +267,10 @@ def f2(arr, grad, tol):
return np.array([-1, -1])


if numba is not None:
f2 = numba.njit(f2)


def find_indexes_of_limit_cycle_max(arr, tol=0.001):
grad = np.gradient(arr)
return f1(arr, grad, tol)
@@ -309,7 +281,6 @@ def find_indexes_of_limit_cycle_min(arr, tol=0.001):
return f2(arr, grad, tol)


-@njit
def _identity(a, b, tol=0.01):
if np.abs(a - b) < tol:
return True
@@ -317,6 +288,10 @@ def _identity(a, b, tol=0.01):
return False


if numba is not None:
_identity = numba.njit(_identity)


def find_indexes_of_limit_cycle_max2(arr, tol=0.001):
if np.ndim(arr) == 1:
grad = np.gradient(arr)


+221 -2  brainpy/backend/__init__.py

@@ -1,4 +1,223 @@
# -*- coding: utf-8 -*-

-from . import numpy_overload
-from .utils import *
from types import ModuleType

from brainpy import errors
from .operators.bk_numpy import *
from .runners.general_runner import GeneralNodeRunner, GeneralNetRunner

_backend = 'numpy' # default backend is NumPy
_node_runner = None
_net_runner = None
_dt = 0.1

CLASS_KEYWORDS = ['self', 'cls']
NEEDED_OPS = ['normal', 'exp', 'matmul', 'sum',
'as_tensor', 'zeros', 'ones', 'arange',
'eye', 'vstack', 'reshape', 'shape', ]
SUPPORTED_BACKEND = {
'numba', 'numba-parallel', 'numba-cuda', 'jax', # JIT framework
'numpy', 'pytorch', 'tensorflow',
}
SYSTEM_KEYWORDS = ['_dt', '_t', '_i']


def set(backend=None, module_or_operations=None, node_runner=None, net_runner=None, dt=None):
"""Basic backend setting function.

Using this function, users can set the backend they prefer. For backend
which is unknown, users can provide `module_or_operations` to specify
the operations needed. Also, users can customize the node runner, or the
network runner, by providing the `node_runner` or `net_runner` keywords.
The default numerical precision `dt` can also be set by this function.

Parameters
----------
backend : str
The backend name.
module_or_operations : module, dict, optional
The module or a dict containing the necessary operations.
node_runner : GeneralNodeRunner
An instance of node runner.
net_runner : GeneralNetRunner
An instance of network runner.
dt : float
The numerical precision.
"""
if dt is not None:
set_dt(dt)

if (backend is None) or (_backend == backend):
return

global_vars = globals()
if backend == 'numpy':
from .operators import bk_numpy

node_runner = GeneralNodeRunner if node_runner is None else node_runner
net_runner = GeneralNetRunner if net_runner is None else net_runner
module_or_operations = bk_numpy if module_or_operations is None else module_or_operations

elif backend == 'pytorch':
from .operators import bk_pytorch

node_runner = GeneralNodeRunner if node_runner is None else node_runner
net_runner = GeneralNetRunner if net_runner is None else net_runner
module_or_operations = bk_pytorch if module_or_operations is None else module_or_operations

elif backend == 'tensorflow':
from .operators import bk_tensorflow

node_runner = GeneralNodeRunner if node_runner is None else node_runner
net_runner = GeneralNetRunner if net_runner is None else net_runner
module_or_operations = bk_tensorflow if module_or_operations is None else module_or_operations

elif backend == 'numba':
from .operators import bk_numba_cpu
from .runners.numba_cpu_runner import NumbaCPUNodeRunner, set_numba_profile

node_runner = NumbaCPUNodeRunner if node_runner is None else node_runner
module_or_operations = bk_numba_cpu if module_or_operations is None else module_or_operations
set_numba_profile(parallel=False)

elif backend == 'numba-parallel':
from .operators import bk_numba_cpu
from .runners.numba_cpu_runner import NumbaCPUNodeRunner, set_numba_profile

node_runner = NumbaCPUNodeRunner if node_runner is None else node_runner
module_or_operations = bk_numba_cpu if module_or_operations is None else module_or_operations
set_numba_profile(parallel=True)

elif backend == 'numba-cuda':
from .operators import bk_numba_cuda
from .runners.numba_cuda_runner import NumbaCudaNodeRunner

node_runner = NumbaCudaNodeRunner if node_runner is None else node_runner
module_or_operations = bk_numba_cuda if module_or_operations is None else module_or_operations

elif backend == 'jax':
from .operators import bk_jax
from .runners.jax_runner import JaxRunner

node_runner = JaxRunner if node_runner is None else node_runner
module_or_operations = bk_jax if module_or_operations is None else module_or_operations

else:
if module_or_operations is None:
raise errors.ModelUseError(f'Backend "{backend}" is unknown, '
f'please provide the "module_or_operations" '
f'to specify the necessary computation units.')
node_runner = GeneralNodeRunner if node_runner is None else node_runner

global_vars['_backend'] = backend
global_vars['_node_runner'] = node_runner
global_vars['_net_runner'] = net_runner
if isinstance(module_or_operations, ModuleType):
set_ops_from_module(module_or_operations)
elif isinstance(module_or_operations, dict):
set_ops(**module_or_operations)
else:
raise errors.ModelUseError('"module_or_operations" must be a module '
'or a dict of operations.')
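
For example, switching the whole framework to the Numba CPU backend with a finer integration step (a sketch, assuming Numba is installed):

from brainpy import backend

backend.set(backend='numba', dt=0.01)  # installs NumbaCPUNodeRunner and the bk_numba_cpu ops
print(backend.get_backend(), backend.get_dt())  # -> numba 0.01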


def set_class_keywords(*args):
"""Set the keywords for class specification.

For example:

>>> class A(object):
>>>     def __init__(cls):
>>>         pass
>>>     def f(self, ):
>>>         pass

In this case, I use "cls" to denote "self". So, I can set this by

>>> set_class_keywords('cls', 'self')

"""
global CLASS_KEYWORDS
CLASS_KEYWORDS = list(args)


def set_dt(dt):
"""Set the numerical integrator precision.

Parameters
----------
dt : float
Numerical integration precision.
"""
assert isinstance(dt, float)
global _dt
_dt = dt


def get_dt():
"""Get the numerical integrator precision.

Returns
-------
dt : float
Numerical integration precision.
"""
return _dt


def set_ops_from_module(module):
global_vars = globals()
for ops in NEEDED_OPS:
if not hasattr(module, ops):
raise ValueError(f'Operation "{ops}" is needed, but is not '
f'defined in module "{module}".')
global_vars[ops] = getattr(module, ops)


def set_ops(**kwargs):
global_vars = globals()
for key, value in kwargs.items():
if key not in NEEDED_OPS:
print(f'"{key}" is not a necessary operation.')
global_vars[key] = value


def get_backend():
"""Get the current backend name.

Returns
-------
backend : str
The name of the current backend.
"""
return _backend


def get_node_runner():
"""Get the current node runner.

Returns
-------
node_runner
The node runner class.
"""
global _node_runner
if _node_runner is None:
from .runners.general_runner import GeneralNodeRunner
_node_runner = GeneralNodeRunner
return _node_runner


def get_net_runner():
"""Get the current network runner.

Returns
-------
net_runner
The network runner.
"""
global _net_runner
if _net_runner is None:
from .runners.general_runner import GeneralNetRunner
_net_runner = GeneralNetRunner
return _net_runner

+1 -0  brainpy/backend/operators/__init__.py

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

+33 -0  brainpy/backend/operators/bk_jax.py

@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-

from jax import numpy
from jax import random

key = random.PRNGKey(0)


def set_seed(seed):
global key
key = random.PRNGKey(seed)


def normal(loc, scale, size):
return loc + scale * random.normal(key, shape=size)


reshape = numpy.reshape
exp = numpy.exp
sum = numpy.sum
zeros = numpy.zeros
eye = numpy.eye
matmul = numpy.matmul
vstack = numpy.vstack
arange = numpy.arange


def shape(x):
size = numpy.shape(x)
if len(size) == 0:
return (1,)
else:
return size
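
A quick sanity check of this operation table (assuming JAX is installed); note that every normal() call reuses the module-level PRNG key until set_seed() installs a fresh one:

from brainpy.backend.operators import bk_jax

bk_jax.set_seed(42)
samples = bk_jax.normal(0.0, 1.0, size=(3,))
print(bk_jax.shape(samples))  # -> (3,)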

+23 -0  brainpy/backend/operators/bk_numba_cpu.py

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-

import numpy as np

from . import bk_numba_overload

as_tensor = np.asarray
normal = np.random.normal
reshape = np.reshape
exp = np.exp
sum = np.sum
zeros = np.zeros
ones = np.ones
eye = np.eye
matmul = np.matmul
vstack = np.vstack
arange = np.arange
shape = np.shape
where = np.where


if __name__ == '__main__':
bk_numba_overload  # referencing the module keeps the overload-registration import from being flagged as unused

+2 -0  brainpy/backend/operators/bk_numba_cuda.py

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-


brainpy/backend/numpy_overload.py → brainpy/backend/operators/bk_numba_overload.py

@@ -2,10 +2,20 @@

import numba
import numpy

from numba.extending import overload


@overload(numpy.shape)
def shape_func(x):
if isinstance(x, (numba.types.Integer, numba.types.Float)):
def shape(x):
return (1,)

return shape
else:
return numpy.shape


@overload(numpy.cbrt)
def cbrt_func(x):
def cbrt(x):
@@ -66,10 +76,12 @@ def heaviside_func(x1, x2):
@overload(numpy.moveaxis)
def moveaxis_func(x, source, destination):
def moveaxis(x, source, destination):
-shape = list(x.shape)
-s = shape.pop(source)
-shape.insert(destination, s)
-return numpy.transpose(x, tuple(shape))
axes = list(range(len(x.shape)))
if source < 0: source = axes[source]
if destination < 0: destination = axes[destination]
s = axes.pop(source)
axes.insert(destination, s)
return numpy.transpose(x, tuple(axes))

return moveaxis

@@ -91,6 +103,7 @@ def swapaxes_func(x, axis1, axis2):
def logspace_func(start, stop, num=50, endpoint=True, base=10.0, dtype=None):
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None):
return numpy.power(base, numpy.linspace(start, stop, num=num, endpoint=endpoint)).astype(dtype)

return logspace


@@ -149,4 +162,3 @@ def average(a, axis=None, weights=None):
return numpy.sum(a * weights, axis=axis) / sum(weights)

return func
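
These overloads exist so that Numba-compiled step functions can call NumPy APIs that Numba lacks natively. A usage sketch (the overloads register as a side effect of the import):

import numba
import numpy as np
from brainpy.backend.operators import bk_numba_overload  # noqa -- registers the overloads

@numba.njit
def cube_root(x):
    return np.cbrt(x)  # resolved through the cbrt overload above

print(cube_root(8.0))  # expected -> 2.0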


+33 -0  brainpy/backend/operators/bk_numpy.py

@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-

import numpy as np

__all__ = [
'as_tensor',
'normal',
'reshape',
'shape',
'exp',
'sum',
'zeros',
'ones',
'eye',
'matmul',
'vstack',
'arange',
]


as_tensor = np.asarray
normal = np.random.normal
reshape = np.reshape
shape = np.shape
exp = np.exp
sum = np.sum
zeros = np.zeros
ones = np.ones
eye = np.eye
matmul = np.matmul
vstack = np.vstack
arange = np.arange


+28 -0  brainpy/backend/operators/bk_pytorch.py

@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-

"""
PyTorch (version xx) is required for this backend.
"""


import torch

as_tensor = torch.tensor
normal = torch.normal
reshape = torch.reshape
exp = torch.exp
sum = torch.sum
zeros = torch.zeros
ones = torch.ones
eye = torch.eye
matmul = torch.matmul
vstack = torch.vstack
arange = torch.arange


def shape(x):
if isinstance(x, (int, float)):
return (1,)
else:
return x.size()
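
The per-backend shape() helpers normalize Python scalars to a one-element shape so that downstream code can always index a dimension; for instance (assuming PyTorch is installed):

import torch
from brainpy.backend.operators import bk_pytorch

print(bk_pytorch.shape(3.0))                # -> (1,)
print(bk_pytorch.shape(torch.zeros(2, 3)))  # -> torch.Size([2, 3])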


+31 -0  brainpy/backend/operators/bk_tensorflow.py

@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-


"""
TensorFlow (version xx) is required for this backend.
"""

import tensorflow as tf

reshape = tf.reshape
exp = tf.math.exp
sum = tf.math.reduce_sum
zeros = tf.zeros
eye = tf.eye
matmul = tf.matmul
arange = tf.range


def vstack(values):
return tf.concat(values, axis=0)


def shape(x):
if isinstance(x, (int, float)):
return (1,)
else:
return tuple(x.shape)


def normal(loc, scale, size):
return tf.random.normal(size, loc, scale)

+339 -0  brainpy/backend/operators/standard.py

@@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-

"""
In this script, we establish the unified and standard
functions for computation backends.
"""

import numpy as np

__all__ = [
# random function
'normal',

# arithmetic operation
'sum',
'exp',
'matmul',

# tensor creation
'eye',
'zeros',
'ones',
'arange',
'as_tensor',

# tensor manipulation
'vstack',

# others
'shape',
'reshape',
]


def normal(loc=0.0, scale=1.0, size=None):
"""The normal operation. We expect "normal" function will behave like "numpy.random.normal"

Draw random samples from a normal (Gaussian) distribution.

Parameters
----------
loc : float or array_like of floats
Mean ("centre") of the distribution.
scale : float or array_like of floats
Standard deviation (spread or "width") of the distribution. Must be
non-negative.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
a single value is returned if ``loc`` and ``scale`` are both scalars.
Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn.

Returns
-------
out : ndarray or scalar
Drawn samples from the parameterized normal distribution.
"""
pass


def sum(tensor, axis=None):
"""The sum operation. We expect "sum" function will behave like "numpy.sum"

Parameters
----------
tensor : array_like
The data to sum.
axis : None or int or tuple of ints, optional
Axis or axes along which a sum is performed. The default,
axis=None, will sum all of the elements of the input array. If
axis is negative it counts from the last to the first axis.

Returns
-------
sum_along_axis : ndarray
An array with the same shape as `a`, with the specified
axis removed. If `a` is a 0-d array, or if `axis` is None, a scalar
is returned. If an output array is specified, a reference to
`out` is returned.

Examples
--------
>>> sum([0.5, 1.5])
2.0
>>> sum([[0, 1], [0, 5]])
6
>>> sum([[0, 1], [0, 5]], axis=0)
array([0, 6])
>>> sum([[0, 1], [0, 5]], axis=1)
array([1, 5])
"""
pass


def exp(x):
"""The exp operation. We expect "exp" function will behave like "numpy.exp"

Parameters
----------
x : array_like
Input values.

Returns
-------
out : ndarray or scalar
Output array, element-wise exponential of `x`.
This is a scalar if `x` is a scalar.
"""
pass


def eye(N, *args, **kwargs):
"""The eye operation. We expect "eye" function will behave like "numpy.eye".

Return a 2-D array with ones on the diagonal and zeros elsewhere.

Parameters
----------
N : int
Number of rows in the output.

Returns
-------
I : tensor of shape (N,N)
A tensor where all elements are equal to zero, except for the `k`-th
diagonal, whose values are equal to one.
"""
pass


def matmul(x1, x2, *args, **kwargs):
"""The matmul operation. We expect "matmul" function will behave like "numpy.matmul".

Parameters
----------
x1, x2 : array_like
Input arrays, scalars not allowed.

Returns
-------
y : tensor
The matrix product of the inputs.
This is a scalar only when both x1, x2 are 1-d vectors.
"""
pass


def vstack(tup):
"""The vstack operation. We expect "vstack" function will behave like "numpy.vstack".

Stack arrays in sequence vertically (row wise).

Parameters
----------
tup : sequence of tensors
The arrays must have the same shape along all but the first axis.
1-D arrays must have the same length.

Returns
-------
stacked : tensor
The tensor formed by stacking the given tensors, will be at least 2-D.

Examples
--------
>>> a = np.array([1, 2, 3])
>>> b = np.array([2, 3, 4])
>>> np.vstack((a,b))
array([[1, 2, 3],
[2, 3, 4]])

>>> a = np.array([[1], [2], [3]])
>>> b = np.array([[2], [3], [4]])
>>> np.vstack((a,b))
array([[1],
[2],
[3],
[2],
[3],
[4]])
"""
pass


def zeros(shape, dtype=None):
"""The zeros operation. We expect "zeros" function will behave like "numpy.zeros".

Return a new array of given shape and type, filled with zeros.

Parameters
----------
shape : int or tuple of ints
Shape of the new array, e.g., ``(2, 3)`` or ``2``.
dtype : data-type, optional
The desired data-type for the array, e.g., `int`. Default is
`float64`.

Returns
-------
out : tensors
Array of zeros with the given shape and dtype.
"""
pass


def ones(shape, dtype=None):
"""The ones operation. We expect "ones" function will behave like "numpy.ones".

Return a new array of given shape and type, filled with ones.

Parameters
----------
shape : int or tuple of ints
Shape of the new array, e.g., ``(2, 3)`` or ``2``.
dtype : data-type, optional
The desired data-type for the array, e.g., `int`. Default is
`float64`.

Returns
-------
out : tensors
Array of ones with the given shape and dtype.
"""
pass


def arange(start=None, *args, **kwargs):
"""The arange operation. We expect "arange" function will behave like "numpy.arange".

Return evenly spaced values within a given interval.

Parameters
----------
start : number, optional
Start of interval. The interval includes this value. The default
start value is 0.
stop : number
End of interval. The interval does not include this value, except
in some cases where `step` is not an integer and floating point
round-off affects the length of `out`.
step : number, optional
Spacing between values. For any output `out`, this is the distance
between two adjacent values, ``out[i+1] - out[i]``. The default
step size is 1. If `step` is specified as a position argument,
`start` must also be given.
dtype : dtype
The type of the output array. If `dtype` is not given, infer the data
type from the other input arguments.

Returns
-------
arange : ndarray
Array of evenly spaced values.

For floating point arguments, the length of the result is
``ceil((stop - start)/step)``. Because of floating point overflow,
this rule may result in the last element of `out` being greater
than `stop`.
"""
pass


def reshape(a, newshape):
"""The reshape operation. We expect "reshape" function will behave like "numpy.reshape".

Gives a new shape to an array without changing its data.

Parameters
----------
a : array_like
Array to be reshaped.
newshape : int or tuple of ints
The new shape should be compatible with the original shape. If
an integer, then the result will be a 1-D array of that length.
One shape dimension can be -1. In this case, the value is
inferred from the length of the array and remaining dimensions.

Returns
-------
reshaped_array : ndarray
This will be a new view object if possible; otherwise, it will
be a copy. Note there is no guarantee of the *memory layout* (C- or
Fortran- contiguous) of the returned array.
"""
pass


def shape(a):
"""The shape operation. We expect "shape" function will behave like "numpy.shape".

Parameters
----------
a : array_like
Input array.

Returns
-------
shape : tuple of ints
The elements of the shape tuple give the lengths of the
corresponding array dimensions.
"""
pass


def as_tensor(a, dtype=None):
"""The as_tensor operation. We expect "as_tensor" function will behave like "numpy.asarray".

Parameters
----------
a : array_like
Input data, in any form that can be converted to an array. This
includes lists, lists of tuples, tuples, tuples of tuples, tuples
of lists and ndarrays.
dtype : data-type, optional
By default, the data-type is inferred from the input data.

Returns
-------
out : ndarray
Array interpretation of `a`. No copy is performed if the input
is already an ndarray with matching dtype and order. If `a` is a
subclass of ndarray, a base class ndarray is returned.
"""
pass
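
Any module or dict that provides these names can back the framework. A sketch of registering a hand-assembled operation table under a custom backend name via brainpy.backend.set() (here every operation simply reuses NumPy):

import numpy as np
from brainpy import backend

ops = {'normal': np.random.normal, 'sum': np.sum, 'exp': np.exp,
       'matmul': np.matmul, 'as_tensor': np.asarray, 'zeros': np.zeros,
       'ones': np.ones, 'arange': np.arange, 'eye': np.eye,
       'vstack': np.vstack, 'reshape': np.reshape, 'shape': np.shape}
backend.set(backend='my-backend', module_or_operations=ops)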

+1 -0  brainpy/backend/runners/__init__.py

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

+258 -0  brainpy/backend/runners/general_runner.py

@@ -0,0 +1,258 @@
# -*- coding: utf-8 -*-

from brainpy import backend
from brainpy import errors
from brainpy.simulation import runner
from . import utils


class GeneralNodeRunner(runner.NodeRunner):
"""General BrainPy Runner for NumPy, PyTorch, TensorFlow, etc.
"""

def __init__(self, pop, steps=None):
steps = pop.steps if steps is None else steps
super(GeneralNodeRunner, self).__init__(host=pop, steps=steps)
self.last_inputs = {}
self.formatted_funcs = {}
self.run_func = None

def get_input_func(self, formatted_inputs, show_code=False):
need_rebuild = False
# check whether the input is changed
# --
new_inputs = {}
input_keep_same = True
old_input_keys = list(self.last_inputs.keys())
for key, val, ops, data_type in formatted_inputs:
# set data
self.set_data(self.input_data_name(key), val)
# compare
if key in old_input_keys:
old_input_keys.remove(key)
if backend.shape(self.last_inputs[key][0]) != backend.shape(val):
input_keep_same = False
if show_code:
print(f'The current "{key}" input shape {backend.shape(val)} is different '
f'from the last input shape {backend.shape(self.last_inputs[key][0])}.')
if self.last_inputs[key][1] != ops:
input_keep_same = False
if show_code:
print(f'The current "{key}" input operation "{ops}" is different '
f'from the last operation "{self.last_inputs[key][1]}". ')
else:
input_keep_same = False
if show_code:
print(f'An input is provided to a new key "{key}" in {self.host}.')
new_inputs[key] = (val, ops, data_type)
self.last_inputs = new_inputs
if len(old_input_keys):
input_keep_same = False
if show_code:
print(f'The inputs of {old_input_keys} in {self.host} are not provided.')

# get the function of the input
# ---
if not input_keep_same:
# codes
input_func_name = 'input_step'
host_name = self.host.name
code_scope = {host_name: self.host}
code_lines = [f'def {input_func_name}(_i):']
for key, val, ops, data_type in formatted_inputs:
if ops == '=':
line = f' {host_name}.{key} = {host_name}.{self.input_data_name(key)}'
else:
line = f' {host_name}.{key} {ops}= {host_name}.{self.input_data_name(key)}'
if data_type == 'iter':
line = line + '[_i]'
code_lines.append(line)
if len(formatted_inputs) == 0:
code_lines.append(' pass')

# function
code = '\n'.join(code_lines)
if show_code:
print(code)
print(code_scope)
print()
exec(compile(code, '', 'exec'), code_scope)
self.set_data(input_func_name, code_scope[input_func_name])
# results
self.formatted_funcs['input'] = {
'func': code_scope[input_func_name],
'scope': {host_name: self.host},
'call': [f'{host_name}.{input_func_name}(_i)'],
}
need_rebuild = True
return need_rebuild

def get_monitor_func(self, mon_length, show_code=False):
mon = self.host.mon
if len(mon['vars']) > 0:
monitor_func_name = 'monitor_step'
host = self.host.name
code_scope = {host: self.host}
code_lines = [f'def {monitor_func_name}(_i):']
for key in mon['vars']:
if not hasattr(self.host, key):
raise errors.ModelUseError(f'{self.host} does not have {key}, '
f'thus it cannot be monitored.')

# initialize monitor array #
shape = backend.shape(getattr(self.host, key))
mon[key] = backend.zeros((mon_length,) + shape)

# add line #
line = f' {host}.mon["{key}"][_i] = {host}.{key}'
code_lines.append(line)

# function
code = '\n'.join(code_lines)
if show_code:
print(code)
print(code_scope)
print()
exec(compile(code, '', 'exec'), code_scope)
self.set_data(monitor_func_name, code_scope[monitor_func_name])
# results
self.formatted_funcs['monitor'] = {
'func': code_scope[monitor_func_name],
'scope': {host: self.host},
'call': [f'{host}.{monitor_func_name}(_i)'],
}

def get_steps_func(self, show_code=False):
for func_name, step in self.steps.items():
class_args, arguments = utils.get_args(step)
host_name = self.host.name

calls = []
for arg in arguments:
if hasattr(self.host, arg):
calls.append(f'{host_name}.{arg}')
elif arg in backend.SYSTEM_KEYWORDS:
calls.append(arg)
else:
raise errors.ModelDefError(f'Step function "{func_name}" of {self.host} '
f'define an unknown argument "{arg}" which is not '
f'an attribute of {self.host} nor the system keywords '
f'{backend.SYSTEM_KEYWORDS}.')
self.formatted_funcs[func_name] = {
'func': step,
'scope': {host_name: self.host},
'call': [f'{host_name}.{func_name}({", ".join(calls)})']
}

def set_data(self, key, data):
setattr(self.host, key, data)

def build(self, formatted_inputs, mon_length, return_code=True, show_code=False):
# inputs check
# --
assert isinstance(formatted_inputs, (tuple, list))
need_rebuild = self.get_input_func(formatted_inputs, show_code=show_code)
self.formatted_funcs['need_rebuild'] = need_rebuild

# the run function does not build before
# ---
if self.run_func is None:
# monitors
self.get_monitor_func(mon_length, show_code=show_code)

# steps
self.get_steps_func(show_code=show_code)

# reshape the monitor
self.host.mon.reshape(run_length=mon_length)

# build the model
if need_rebuild or self.run_func is None:
code_scope = dict()
code_lines = ['def run_func(_t, _i, _dt):']
for process in self.get_schedule():
if (process not in self.formatted_funcs) and (process in ['input', 'monitor']):
continue
process_result = self.formatted_funcs[process]
code_scope.update(process_result['scope'])
code_lines.extend(process_result['call'])

# function
code = '\n '.join(code_lines)
if show_code:
print(code)
print(code_scope)
print()
exec(compile(code, '', 'exec'), code_scope)
self.run_func = code_scope['run_func']

if return_code:
return self.run_func, self.formatted_funcs
else:
return self.run_func

@staticmethod
def input_data_name(key):
return f'_input_data_of_{key.replace(".", "_")}'
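
For a host hypothetically named NeuGroup0 with formatted inputs [('V', data, '+', 'iter'), ('I_ext', 1.0, '=', 'fix')], the generated input_step would look roughly like this (hypothetical names, following the code template above):

def input_step(_i):
    NeuGroup0.V += NeuGroup0._input_data_of_V[_i]     # op '+', 'iter' data indexed per step
    NeuGroup0.I_ext = NeuGroup0._input_data_of_I_ext  # op '=', fixed data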


class GeneralNetRunner(runner.NetRunner):
def __init__(self, all_nodes):
super(GeneralNetRunner, self).__init__(all_nodes=all_nodes)
self.run_func = None

def build(self, run_length, formatted_inputs, return_code=False, show_code=False):
"""Build the network.

Parameters
----------
run_length : int
The running length.
formatted_inputs : dict
The user-defined inputs.
show_code : bool
Show the formatted code.
return_code : bool
Return the code lines and code scope.

Returns
-------
step_func : callable
The step function.
"""
if not isinstance(run_length, int):
raise errors.ModelUseError(f'The running length must be an int, '
f'but we get {type(run_length)}')

# codes for step function
need_rebuild = False
code_scope = {}
code_lines = ['def run_func(_t, _i, _dt):']
for obj in self.all_nodes.values():
f, codes = obj.build(inputs=formatted_inputs.get(obj.name, []),
inputs_is_formatted=True,
mon_length=run_length,
return_code=True,
show_code=show_code)
need_rebuild |= codes['need_rebuild']
for p in obj.get_schedule():
if (p not in codes) and (p in ['input', 'monitor']):
continue
p_codes = codes[p]
code_scope.update(p_codes['scope'])
code_lines.extend(p_codes['call'])

# compile the step function
if (self.run_func is None) or need_rebuild:
code = '\n '.join(code_lines)
if show_code:
print(code)
print(code_scope)
print()
exec(compile(code, '', 'exec'), code_scope)
self.run_func = code_scope['run_func']

if return_code:
return self.run_func, code_lines, code_scope
else:
return self.run_func

+5 -0  brainpy/backend/runners/jax_runner.py

@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-


class JaxRunner(object):
pass

+583 -0  brainpy/backend/runners/numba_cpu_runner.py

@@ -0,0 +1,583 @@
# -*- coding: utf-8 -*-

import ast
import inspect
import re

import numba

from brainpy import backend
from brainpy import errors
from brainpy import tools
from brainpy.simulation import delay
from . import utils
from .general_runner import GeneralNodeRunner

__all__ = [
'set_numba_profile',
'get_numba_profile',

'StepFuncReader',
'analyze_step_func',
'get_func_body_code',
'get_num_indent',

'NumbaCPUNodeRunner',
]

NUMBA_PROFILE = {
'nopython': True,
'fastmath': True,
'nogil': True,
'parallel': False
}


def set_numba_profile(**kwargs):
"""Set the compilation options of Numba JIT function.

Parameters
----------
kwargs : Any
The arguments, including ``nopython``, ``fastmath``,
``nogil``, and ``parallel``.
"""
global NUMBA_PROFILE

if 'fastmath' in kwargs:
NUMBA_PROFILE['fastmath'] = kwargs.pop('fastmath')
if 'nopython' in kwargs:
NUMBA_PROFILE['nopython'] = kwargs.pop('nopython')
if 'nogil' in kwargs:
NUMBA_PROFILE['nogil'] = kwargs.pop('nogil')
if 'parallel' in kwargs:
NUMBA_PROFILE['parallel'] = kwargs.pop('parallel')


def get_numba_profile():
"""Get the compilation setting of numba JIT function.

Returns
-------
numba_setting : dict
Numba setting.
"""
return NUMBA_PROFILE
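
Usage sketch: relax the JIT options before any step function is built, for instance to ease debugging:

from brainpy.backend.runners.numba_cpu_runner import set_numba_profile, get_numba_profile

set_numba_profile(parallel=True, fastmath=False)
print(get_numba_profile())
# -> {'nopython': True, 'fastmath': False, 'nogil': True, 'parallel': True}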


class StepFuncReader(ast.NodeVisitor):
def __init__(self, host):
self.lefts = []
self.rights = []
self.lines = []
self.visited_nodes = set()

self.host = host
# get delay information
self.delay_call = {}

def visit_Assign(self, node, level=0):
if node not in self.visited_nodes:
prefix = ' ' * level
expr = tools.ast2code(ast.fix_missing_locations(node.value))
targets = []
for target in node.targets:
targets.append(tools.ast2code(ast.fix_missing_locations(target)))
_target = ' = '.join(targets)

self.rights.append(expr)
self.lefts.append(_target)
self.lines.append(f'{prefix}{_target} = {expr}')

self.visited_nodes.add(node)

self.generic_visit(node)

def visit_AugAssign(self, node, level=0):
if node not in self.visited_nodes:
prefix = ' ' * level
op = tools.ast2code(ast.fix_missing_locations(node.op))
expr = tools.ast2code(ast.fix_missing_locations(node.value))
target = tools.ast2code(ast.fix_missing_locations(node.target))

self.lefts.append(target)
self.rights.append(f'{target} {op} {expr}')
self.lines.append(f"{prefix}{target} = {target} {op} {expr}")

self.visited_nodes.add(node)

self.generic_visit(node)

def visit_AnnAssign(self, node):
raise NotImplementedError('Do not support an assignment with '
'a type annotation in Numba backend.')

def visit_node_not_assign(self, node, level=0):
if node not in self.visited_nodes:
prefix = ' ' * level
expr = tools.ast2code(ast.fix_missing_locations(node))
self.lines.append(f'{prefix}{expr}')
self.lefts.append('')
self.rights.append(expr)
self.visited_nodes.add(node)

self.generic_visit(node)

def visit_Assert(self, node, level=0):
self.visit_node_not_assign(node, level)

def visit_Expr(self, node, level=0):
self.visit_node_not_assign(node, level)

def visit_Expression(self, node, level=0):
self.visit_node_not_assign(node, level)

def visit_content_in_condition_control(self, node, level):
if isinstance(node, ast.Expr):
self.visit_Expr(node, level)
elif isinstance(node, ast.Assert):
self.visit_Assert(node, level)
elif isinstance(node, ast.Assign):
self.visit_Assign(node, level)
elif isinstance(node, ast.AugAssign):
self.visit_AugAssign(node, level)
elif isinstance(node, ast.If):
self.visit_If(node, level)
elif isinstance(node, ast.For):
self.visit_For(node, level)
elif isinstance(node, ast.While):
self.visit_While(node, level)
elif isinstance(node, ast.Call):
self.visit_Call(node, level)
elif isinstance(node, ast.Raise):
self.visit_Raise(node, level)
else:
code = tools.ast2code(ast.fix_missing_locations(node))
raise errors.CodeError(f'BrainPy does not support {type(node)} '
f'in Numba backend.\n\n{code}')

def visit_attr(self, node):
if isinstance(node, ast.Attribute):
r = self.visit_attr(node.value)
return [node.attr] + r
elif isinstance(node, ast.Name):
return [node.id]
else:
raise ValueError(f'Unknown node type {type(node)} in the attribute expression.')

def visit_Call(self, node, level=0):
if node in self.delay_call:
return
calls = self.visit_attr(node.func)
calls = calls[::-1]

# delay push / delay pull
if calls[-1] in ['push', 'pull']:
obj = self.host
for data in calls[1:-1]:
obj = getattr(obj, data)
obj_func = getattr(obj, calls[-1])
if isinstance(obj, delay.ConstantDelay) and callable(obj_func):
func = ".".join(calls)
args = []
for arg in node.args:
args.append(tools.ast2code(ast.fix_missing_locations(arg)))
keywords = []
for arg in node.keywords:
keywords.append(tools.ast2code(ast.fix_missing_locations(arg)))
delay_var = '.'.join([self.host.name] + calls[1:-1])
if calls[-1] == 'push':
kws_append = [f'delay_data={delay_var}_delay_data',
f'delay_in_idx={delay_var}_delay_in_idx', ]
data_need_pass = [f'{self.host.name}.{".".join(calls[1:-1])}.delay_data',
f'{self.host.name}.{".".join(calls[1:-1])}.delay_in_idx']
else:
kws_append = [f'delay_data={delay_var}_delay_data',
f'delay_out_idx={delay_var}_delay_out_idx', ]
data_need_pass = [f'{self.host.name}.{".".join(calls[1:-1])}.delay_data',
f'{self.host.name}.{".".join(calls[1:-1])}.delay_out_idx']
org_call = tools.ast2code(ast.fix_missing_locations(node))
rep_call = f'{func}({", ".join(args + keywords + kws_append)})'
self.delay_call[node] = dict(type=calls[-1],
args=args,
keywords=keywords,
kws_append=kws_append,
func=func,
org_call=org_call,
rep_call=rep_call,
data_need_pass=data_need_pass)

self.generic_visit(node)

def visit_If(self, node, level=0):
if node not in self.visited_nodes:
# If condition
prefix = ' ' * level
compare = tools.ast2code(ast.fix_missing_locations(node.test))
self.rights.append(f'if {compare}:')
self.lines.append(f'{prefix}if {compare}:')

# body
for expr in node.body:
self.visit_content_in_condition_control(expr, level + 1)

# elif
while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
node = node.orelse[0]
compare = tools.ast2code(ast.fix_missing_locations(node.test))
self.lines.append(f'{prefix}elif {compare}:')
for expr in node.body:
self.visit_content_in_condition_control(expr, level + 1)

# else:
if len(node.orelse) > 0:
self.lines.append(f'{prefix}else:')
for expr in node.orelse:
self.visit_content_in_condition_control(expr, level + 1)

self.visited_nodes.add(node)

self.generic_visit(node)

def visit_For(self, node, level=0):
if node not in self.visited_nodes:
prefix = ' ' * level
# target
target = tools.ast2code(ast.fix_missing_locations(node.target))
# iter
iter = tools.ast2code(ast.fix_missing_locations(node.iter))
self.rights.append(f'{target} in {iter}')
self.lines.append(prefix + f'for {target} in {iter}:')
# body
for expr in node.body:
self.visit_content_in_condition_control(expr, level + 1)
# else
if len(node.orelse) > 0:
self.lines.append(prefix + 'else:')
for expr in node.orelse:
self.visit_content_in_condition_control(expr, level + 1)

self.visited_nodes.add(node)
self.generic_visit(node)

def visit_While(self, node, level=0):
if node not in self.visited_nodes:
prefix = ' ' * level
# test
test = tools.ast2code(ast.fix_missing_locations(node.test))
self.rights.append(test)
self.lines.append(prefix + f'while {test}:')
# body
for expr in node.body:
self.visit_content_in_condition_control(expr, level + 1)
# else
if len(node.orelse) > 0:
self.lines.append(prefix + 'else:')
for expr in node.orelse:
self.visit_content_in_condition_control(expr, level + 1)

self.visited_nodes.add(node)
self.generic_visit(node)

def visit_Raise(self, node, level=0):
if node not in self.visited_nodes:
prefix = ' ' * level
line = tools.ast2code(ast.fix_missing_locations(node))
self.lines.append(prefix + line)

self.visited_nodes.add(node)
self.generic_visit(node)

def visit_Try(self, node):
raise errors.CodeError('"try" blocks are not supported in the Numba backend.')

def visit_With(self, node):
raise errors.CodeError('"with" blocks are not supported in the Numba backend.')

def visit_Delete(self, node):
raise errors.CodeError('"del" statements are not supported in the Numba backend.')


def analyze_step_func(host, f):
"""Analyze the step functions in a population.

Parameters
----------
f : callable
The step function.
host : Population
The data and the function host.

Returns
-------
results : dict
The code string of the function, the code scope,
the data need pass into the arguments,
the data need return.
"""
code_string = tools.deindent(inspect.getsource(f)).strip()
tree = ast.parse(code_string)

# arguments
# ---
args = tools.ast2code(ast.fix_missing_locations(tree.body[0].args)).split(',')

# code AST analysis
# ---
formatter = StepFuncReader(host=host)
formatter.visit(tree)

# data assigned by self.xx in line right
# ---
self_data_in_right = []
if args[0] in backend.CLASS_KEYWORDS:
code = ', \n'.join(formatter.rights)
self_data_in_right = re.findall('\\b' + args[0] + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b', code)
self_data_in_right = list(set(self_data_in_right))

# data assigned by self.xxx in line left
# ---
code = ', \n'.join(formatter.lefts)
self_data_without_index_in_left = []
self_data_with_index_in_left = []
if args[0] in backend.CLASS_KEYWORDS:
class_p1 = '\\b' + args[0] + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b'
self_data_without_index_in_left = set(re.findall(class_p1, code))
class_p2 = '(\\b' + args[0] + '\\.[A-Za-z_][A-Za-z0-9_.]*)\\[.*\\]'
self_data_with_index_in_left = set(re.findall(class_p2, code))
self_data_with_index_in_left = list(self_data_with_index_in_left)
self_data_without_index_in_left = list(self_data_without_index_in_left)

# code scope
# ---
closure_vars = inspect.getclosurevars(f)
code_scope = dict(closure_vars.nonlocals)
code_scope.update(closure_vars.globals)

# final
# ---
self_data_in_right = sorted(self_data_in_right)
self_data_without_index_in_left = sorted(self_data_without_index_in_left)
self_data_with_index_in_left = sorted(self_data_with_index_in_left)

analyzed_results = {
'delay_call': formatter.delay_call,
'code_string': '\n'.join(formatter.lines),
'code_scope': code_scope,
'self_data_in_right': self_data_in_right,
'self_data_without_index_in_left': self_data_without_index_in_left,
'self_data_with_index_in_left': self_data_with_index_in_left,
}

return analyzed_results
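
# ----------------------------------------------------------------------
# Minimal stdlib sketch of the analysis above (illustrative only, not
# used by BrainPy; ``_ToyGroup`` is a made-up class, and ``ast.unparse``
# requires Python >= 3.9). It shows how ``self.xxx`` reads on the
# right-hand sides and assigned ``self.xxx`` targets on the left-hand
# sides are collected with the same regular expression used above.
#
#     import ast, inspect, re, textwrap
#
#     class _ToyGroup:
#         def update(self, _t):
#             self.V = self.V + 0.1 * (self.inp - self.V)
#             self.spike = self.V > 1.
#
#     src = textwrap.dedent(inspect.getsource(_ToyGroup.update))
#     rights, lefts = [], []
#     for node in ast.walk(ast.parse(src)):
#         if isinstance(node, ast.Assign):
#             rights.append(ast.unparse(node.value))
#             lefts.append(ast.unparse(node.targets[0]))
#     pat = r'\bself\.[A-Za-z_][A-Za-z0-9_.]*\b'
#     sorted(set(re.findall(pat, ', '.join(rights))))  # ['self.V', 'self.inp']
#     sorted(set(re.findall(pat, ', '.join(lefts))))   # ['self.V', 'self.spike']
# ----------------------------------------------------------------------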


def get_func_body_code(code_string, lambda_func=False):
"""Get the main body code of a function.

Parameters
----------
code_string : str
The code string of the function.
lambda_func : bool
Whether the code comes from a lambda function.

Returns
-------
code_body : str
The code body.
"""
if lambda_func:
splits = code_string.split(':', 1)
if len(splits) != 2:
raise ValueError(f'Can not parse function: \n{code_string}')
main_code = f'return {splits[1]}'
else:
func_codes = code_string.split('\n')
idx = 0
for i, line in enumerate(func_codes):
idx += 1
line = line.replace(' ', '')
if '):' in line:
break
else:
raise ValueError(f'Can not parse function: \n{code_string}')
main_code = '\n'.join(func_codes[idx:])
return main_code


def get_num_indent(code_string, spaces_per_tab=4):
"""Get the indent of a patch of source code.

Parameters
----------
code_string : str
The code string.
spaces_per_tab : int
The spaces per tab.

Returns
-------
num_indent : int
The minimal indentation (number of spaces) over the non-empty lines.
"""
lines = code_string.split('\n')
min_indent = 1000
for line in lines:
if line.strip() == '':
continue
line = line.replace('\t', ' ' * spaces_per_tab)
num_indent = len(line) - len(line.lstrip())
if num_indent < min_indent:
min_indent = num_indent
return min_indent
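
# Quick self-check of the two helpers above (illustrative): strip the
# ``def`` header from a function's source, then measure the remaining
# indentation, exactly as ``class2func`` does below.
_demo_src = 'def f(a, b):\n    c = a + b\n    return c'
_demo_body = get_func_body_code(_demo_src)
assert _demo_body == '    c = a + b\n    return c'
assert get_num_indent(_demo_body) == 4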


def class2func(cls_func, host, func_name=None, show_code=False):
"""Transform the function in a class into the ordinary function which is
compatible with the Numba JIT compilation.

Parameters
----------
cls_func : function
The function of the instantiated class.
func_name : str
The function name. If not given, it will get the function by `cls_func.__name__`.
show_code : bool
Whether show the code.

Returns
-------
new_func : function
The transformed function.
"""
class_arg, arguments = utils.get_args(cls_func)
func_name = cls_func.__name__ if func_name is None else func_name
host_name = host.name

# arguments 1
calls = []
for arg in arguments:
if hasattr(host, arg):
calls.append(f'{host_name}.{arg}')
elif arg in backend.SYSTEM_KEYWORDS:
calls.append(arg)
else:
raise errors.ModelDefError(f'Step function "{func_name}" of {host} '
f'define an unknown argument "{arg}" which is not '
f'an attribute of {host} nor the system keywords '
f'{backend.SYSTEM_KEYWORDS}.')

# analysis
analyzed_results = analyze_step_func(host=host, f=cls_func)
delay_call = analyzed_results['delay_call']
main_code = analyzed_results['code_string']
code_scope = analyzed_results['code_scope']
self_data_in_right = analyzed_results['self_data_in_right']
self_data_without_index_in_left = analyzed_results['self_data_without_index_in_left']
self_data_with_index_in_left = analyzed_results['self_data_with_index_in_left']
num_indent = get_num_indent(main_code)
data_need_pass = sorted(list(set(self_data_in_right + self_data_with_index_in_left)))
data_need_return = self_data_without_index_in_left

# check delay
replaces_early = {}
replaces_later = {}
if len(delay_call) > 0:
for delay_ in delay_call.values():
# delay_ = dict(type=calls[-1],
# args=args,
# keywords=keywords,
# kws_append=kws_append,
# func=func,
# org_call=org_call,
# rep_call=rep_call,
# data_need_pass=data_need_pass)
if delay_['type'] == 'push':
if len(delay_['args'] + delay_['keywords']) == 2:
func = numba.njit(delay.push_type2)
elif len(delay_['args'] + delay_['keywords']) == 1:
func = numba.njit(delay.push_type1)
else:
raise ValueError(f'Unknown delay push. {delay_}')
else:
if len(delay_['args'] + delay_['keywords']) == 1:
func = numba.njit(delay.pull_type1)
elif len(delay_['args'] + delay_['keywords']) == 0:
func = numba.njit(delay.pull_type0)
else:
raise ValueError(f'Unknown delay pull. {delay_}')
delay_call_name = delay_['func']
data_need_pass.remove(delay_call_name)
data_need_pass.extend(delay_['data_need_pass'])
replaces_early[delay_['org_call']] = delay_['rep_call']
replaces_later[delay_call_name] = delay_call_name.replace('.', '_')
code_scope[delay_call_name.replace('.', '_')] = func
for target, dest in replaces_early.items():
main_code = main_code.replace(target, dest)

# arguments 2: data need pass
new_args = arguments + []
for data in sorted(set(data_need_pass)):
splits = data.split('.')
replaces_later[data] = data.replace('.', '_')
obj = host
for attr in splits[1:]:
obj = getattr(obj, attr)
if callable(obj):
code_scope[data.replace('.', '_')] = obj
continue
new_args.append(data.replace('.', '_'))
calls.append('.'.join([host_name] + splits[1:]))

# data need return
assigns = []
returns = []
for data in data_need_return:
splits = data.split('.')
assigns.append('.'.join([host_name] + splits[1:]))
returns.append(data.replace('.', '_'))
replaces_later[data] = data.replace('.', '_')

# code scope
code_scope[host_name] = host

# codes
header = f'def new_{func_name}({", ".join(new_args)}):\n'
main_code = header + tools.indent(main_code, spaces_per_tab=2)
if len(returns):
main_code += f'\n{" " * num_indent + " "}return {", ".join(returns)}'
main_code = tools.word_replace(main_code, replaces_later)
if show_code:
print(main_code)
print(code_scope)
print()

# recompile
exec(compile(main_code, '', 'exec'), code_scope)
func = code_scope[f'new_{func_name}']
func = numba.jit(**NUMBA_PROFILE)(func)
return func, calls, assigns
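
# ----------------------------------------------------------------------
# What the transformation produces (illustrative sketch with made-up
# names; delay handling omitted). For a host named ``LIF0`` holding
# arrays ``V`` and ``inp``, the bound step function
#
#     def update(self, _t):
#         self.V += (self.inp - self.V) * _t
#
# is recompiled into a flat function whose arguments carry the host
# data explicitly, so Numba can JIT it without falling back to object
# mode:
#
#     def new_update(_t, self_V, self_inp):
#         self_V = self_V + (self_inp - self_V) * _t
#         return self_V
#
# and the runner later calls it as
#
#     LIF0.V = LIF0.new_update(_t, LIF0.V, LIF0.inp)
# ----------------------------------------------------------------------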


class NumbaCPUNodeRunner(GeneralNodeRunner):
def get_steps_func(self, show_code=False):
for func_name, step in self.steps.items():
host = step.__self__
func, calls, assigns = class2func(cls_func=step, host=host, func_name=func_name, show_code=show_code)
setattr(host, f'new_{func_name}', func)

# finale
assignment_line = ''
if len(assigns):
assignment_line = f'{", ".join(assigns)} = '
self.formatted_funcs[func_name] = {
'func': func,
'scope': {host.name: host},
'call': [f'{assignment_line}{host.name}.new_{func_name}({", ".join(calls)})']
}

+ 158
- 0
brainpy/backend/runners/numba_cuda_runner.py View File

@@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-

import ast
import inspect
import re

from brainpy import backend
from brainpy import tools
from brainpy.simulation.brain_objects import SynConn, NeuGroup
from .numba_cpu_runner import NumbaCPUNodeRunner
from .numba_cpu_runner import StepFuncReader

__all__ = [
'NumbaCudaNodeRunner',
]


class CudaStepFuncReader(StepFuncReader):
def __init__(self, host):
super(CudaStepFuncReader, self).__init__(host=host)

self.need_add_cuda = False
# get pre assignment
self.pre_assign = []
# get post assignment
self.post_assign = []

def check_atomic_ops(self, target):
if isinstance(self.host, SynConn) and isinstance(target, ast.Subscript):
values = self.visit_attr(target.value)
slice_ = tools.ast2code(ast.fix_missing_locations(target.slice))
if len(values) >= 3 and values[-1] in backend.CLASS_KEYWORDS:
obj = getattr(self.host, values[-2])
if isinstance(obj, NeuGroup):
target_ = '.'.join(values[::-1])
return target_, slice_
return None

def visit_Assign(self, node, level=0):
self.generic_visit(node)
prefix = ' ' * level
expr = tools.ast2code(ast.fix_missing_locations(node.value))
self.rights.append(expr)

check = None
if len(node.targets) == 1:
check = self.check_atomic_ops(node.targets[0])

if check is None:
targets = []
for target in node.targets:
targets.append(tools.ast2code(ast.fix_missing_locations(target)))
_target = ' = '.join(targets)
self.lefts.append(_target)
self.lines.append(f'{prefix}{_target} = {expr}')
else:
target, slice_ = check
self.lefts.append(target)
self.lines.append(f'{prefix}cuda.atomic.add({target}, {slice_}, {expr})')

def visit_AugAssign(self, node, level=0):
self.generic_visit(node)
prefix = ' ' * level
op = tools.ast2code(ast.fix_missing_locations(node.op))
expr = tools.ast2code(ast.fix_missing_locations(node.value))

check = self.check_atomic_ops(node.target)
if check is None:
target = tools.ast2code(ast.fix_missing_locations(node.target))
self.lefts.append(target)
self.rights.append(expr)
self.lines.append(f"{prefix}{target} {op}= {expr}")
else:
if op == '+':
pass
elif op == '-':
expr = '-' + expr
else:
raise ValueError(f'Cannot translate "{op}=" into an atomic CUDA '
f'operation; only "+=" and "-=" are supported.')
target, slice_ = check
self.lefts.append(target)
self.lines.append(f'{prefix}cuda.atomic.add({target}, {slice_}, {expr})')
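
# ----------------------------------------------------------------------
# Illustration of the atomic rewriting (stdlib-only sketch; requires
# Python >= 3.9 for ``ast.unparse``). An in-place update on a
# post-synaptic variable such as
#
#     self.post.V[post_ids[i]] += g
#
# is emitted as ``cuda.atomic.add(self.post.V, post_ids[i], g)``:
#
#     node = ast.parse('self.post.V[post_ids[i]] += g').body[0]
#     target = ast.unparse(node.target.value)   # 'self.post.V'
#     index = ast.unparse(node.target.slice)    # 'post_ids[i]'
#     value = ast.unparse(node.value)           # 'g'
#     f'cuda.atomic.add({target}, {index}, {value})'
# ----------------------------------------------------------------------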


def analyze_step_func(f, host):
"""Analyze the step functions in a population.

Parameters
----------
f : callable
The step function.
host : Population
The data and the function host.

Returns
-------
results : dict
The code string of the function, the code scope,
the data need pass into the arguments,
the data need return.
"""

code_string = tools.deindent(inspect.getsource(f)).strip()
tree = ast.parse(code_string)

# arguments
# ---
args = tools.ast2code(ast.fix_missing_locations(tree.body[0].args)).split(',')

# code lines
# ---
formatter = CudaStepFuncReader(host=host)
formatter.visit(tree)

# data assigned by self.xx in line right
# ---
self_data_in_right = []
if args[0] in backend.CLASS_KEYWORDS:
code = ', \n'.join(formatter.rights)
self_data_in_right = re.findall('\\b' + args[0] + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b', code)
self_data_in_right = list(set(self_data_in_right))

# data assigned by self.xxx in line left
# ---
code = ', \n'.join(formatter.lefts)
self_data_without_index_in_left = []
self_data_with_index_in_left = []
if args[0] in backend.CLASS_KEYWORDS:
class_p1 = '\\b' + args[0] + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b'
self_data_without_index_in_left = set(re.findall(class_p1, code))
class_p2 = '(\\b' + args[0] + '\\.[A-Za-z_][A-Za-z0-9_.]*)\\[.*\\]'
self_data_with_index_in_left = set(re.findall(class_p2, code))
self_data_without_index_in_left -= self_data_with_index_in_left
self_data_without_index_in_left = list(self_data_without_index_in_left)

# code scope
# ---
closure_vars = inspect.getclosurevars(f)
code_scope = dict(closure_vars.nonlocals)
code_scope.update(closure_vars.globals)

# final
# ---
self_data_in_right = sorted(self_data_in_right)
self_data_without_index_in_left = sorted(self_data_without_index_in_left)
self_data_with_index_in_left = sorted(self_data_with_index_in_left)

analyzed_results = {
'code_string': code_string,
'code_scope': code_scope,
'self_data_in_right': self_data_in_right,
'self_data_without_index_in_left': self_data_without_index_in_left,
'self_data_with_index_in_left': self_data_with_index_in_left,
}

return analyzed_results


class NumbaCudaNodeRunner(NumbaCPUNodeRunner):
pass

+ 57
- 0
brainpy/backend/runners/utils.py View File

@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-


import inspect

from brainpy import backend
from brainpy import errors

__all__ = [
'get_args'
]


def get_args(f):
"""Get the function arguments.

Parameters
----------
f : callable
The function.

Returns
-------
args : tuple
The variable names, the other arguments, and the original args.
"""

# 1. get the function arguments
parameters = inspect.signature(f).parameters

arguments = []
for name, par in parameters.items():
if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
arguments.append(par.name)

elif par.kind is inspect.Parameter.KEYWORD_ONLY:
arguments.append(par.name)

elif par.kind is inspect.Parameter.VAR_POSITIONAL:
raise errors.ModelDefError('Step functions do not support variable '
'positional parameters, e.g., *args.')
elif par.kind is inspect.Parameter.POSITIONAL_ONLY:
raise errors.ModelDefError('Step functions do not support positional-only '
'parameters, e.g., the "/" marker.')
elif par.kind is inspect.Parameter.VAR_KEYWORD:
raise errors.ModelDefError(f'Step functions do not support variable keyword '
f'arguments: {str(par)}')
else:
raise errors.ModelDefError(f'Unknown argument type: {par.kind}')

# 2. check the function arguments
class_kw = None
if len(arguments) > 0 and arguments[0] in backend.CLASS_KEYWORDS:
class_kw = arguments[0]
arguments = arguments[1:]
for a in arguments:
if a in backend.CLASS_KEYWORDS:
raise errors.DiffEqError(f'Class keywords "{a}" must be defined '
f'as the first argument.')
return class_kw, arguments
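
# ----------------------------------------------------------------------
# Usage sketch (illustrative; assumes 'self' is listed in
# ``backend.CLASS_KEYWORDS``): for a step function written as a bound
# method, the class keyword is split off from the real arguments.
#
#     def update(self, _t, _i):
#         ...
#
#     get_args(update)  # -> ('self', ['_t', '_i'])
# ----------------------------------------------------------------------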

+ 0
- 44
brainpy/backend/utils.py View File

@@ -1,44 +0,0 @@
# -*- coding: utf-8 -*-


import math

import numba
import numpy

from .. import profile

__all__ = [
'func_in_numpy_or_math',
'normal_like',
]

# Get functions in math
_functions_in_math = []
for key in dir(math):
if not key.startswith('__'):
_functions_in_math.append(getattr(math, key))

# Get functions in NumPy
_functions_in_numpy = []
for key in dir(numpy):
if not key.startswith('__'):
_functions_in_numpy.append(getattr(numpy, key))
for key in dir(numpy.random):
if not key.startswith('__'):
_functions_in_numpy.append(getattr(numpy.random, key))
for key in dir(numpy.linalg):
if not key.startswith('__'):
_functions_in_numpy.append(getattr(numpy.linalg, key))


def func_in_numpy_or_math(func):
return func in _functions_in_math or func in _functions_in_numpy


@numba.generated_jit(**profile.get_numba_profile())
def normal_like(x):
if isinstance(x, (numba.types.Integer, numba.types.Float)):
return lambda x: numpy.random.normal()
else:
return lambda x: numpy.random.normal(0., 1.0, x.shape)

+ 98
- 85
brainpy/connectivity/base.py View File

@@ -1,13 +1,16 @@
# -*- coding: utf-8 -*-

import numba as nb
import numpy as np
import abc

from .. import profile
from ..errors import ModelUseError
from brainpy import backend
from brainpy import errors

try:
import numba as nb
except ModuleNotFoundError:
nb = None

__all__ = [
'Connector',
'ij2mat',
'mat2ij',
'pre2post',
@@ -16,9 +19,19 @@ __all__ = [
'post2syn',
'pre_slice_syn',
'post_slice_syn',

'AbstractConnector',
'Connector',
]


def _numba_backend():
r = backend.get_backend().startswith('numba')
if r and nb is None:
raise errors.PackageMissingError('Please install numba for numba backend.')
return r


def ij2mat(i, j, num_pre=None, num_post=None):
"""Convert i-j connection to matrix connection.

@@ -39,17 +52,14 @@ def ij2mat(i, j, num_pre=None, num_post=None):
A 2D ndarray connectivity matrix.
"""
if len(i) != len(j):
raise ModelUseError('"i" and "j" must be the equal length.')
raise errors.ModelUseError('"i" and "j" must be the equal length.')
if num_pre is None:
print('WARNING: "num_pre" is not provided, the result may not be accurate.')
num_pre = np.max(i)
num_pre = i.max()
if num_post is None:
print('WARNING: "num_post" is not provided, the result may not be accurate.')
num_post = np.max(j)

i = np.asarray(i, dtype=np.int64)
j = np.asarray(j, dtype=np.int64)
conn_mat = np.zeros((num_pre, num_post), dtype=np.float_)
num_post = j.max()
conn_mat = backend.zeros((num_pre, num_post))
conn_mat[i, j] = 1.
return conn_mat
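
# Example (illustrative, plain-NumPy sketch of the semantics): three
# synapses 0->2, 1->0 and 2->1 become a 3x3 matrix with ones at the
# connected positions.
#
#     import numpy as np
#     i, j = np.array([0, 1, 2]), np.array([2, 0, 1])
#     mat = np.zeros((3, 3))
#     mat[i, j] = 1.
#     # mat[0, 2] == mat[1, 0] == mat[2, 1] == 1.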

@@ -61,20 +71,18 @@ def mat2ij(conn_mat):
----------
conn_mat : np.ndarray
Connectivity matrix with `(num_pre, num_post)` shape.
Returns
-------
conn_tuple : tuple
(Pre-synaptic neuron indexes,
post-synaptic neuron indexes).
"""
conn_mat = np.asarray(conn_mat)
if np.ndim(conn_mat) != 2:
raise ModelUseError('Connectivity matrix must be in the shape of (num_pre, num_post).')
pre_ids, post_ids = np.where(conn_mat > 0)
pre_ids = np.ascontiguousarray(pre_ids, dtype=np.int_)
post_ids = np.ascontiguousarray(post_ids, dtype=np.int_)
return pre_ids, post_ids
if len(backend.shape(conn_mat)) != 2:
raise errors.ModelUseError('Connectivity matrix must be in the '
'shape of (num_pre, num_post).')
pre_ids, post_ids = backend.where(conn_mat > 0)
return backend.as_tensor(pre_ids), backend.as_tensor(post_ids)


def pre2post(i, j, num_pre=None):
@@ -95,20 +103,20 @@ def pre2post(i, j, num_pre=None):
The conn list of pre2post.
"""
if len(i) != len(j):
raise ModelUseError('The length of "i" and "j" must be the same.')
raise errors.ModelUseError('The length of "i" and "j" must be the same.')
if num_pre is None:
print('WARNING: "num_pre" is not provided, the result may not be accurate.')
num_pre = np.max(i)
num_pre = i.max()

pre2post_list = [[] for _ in range(num_pre)]
for pre_id, post_id in zip(i, j):
pre2post_list[pre_id].append(post_id)
pre2post_list = [np.array(l) for l in pre2post_list]
pre2post_list = [backend.as_tensor(l) for l in pre2post_list]

if profile.is_jit():
if _numba_backend():
pre2post_list_nb = nb.typed.List()
for pre_id in range(num_pre):
pre2post_list_nb.append(np.int64(pre2post_list[pre_id]))
pre2post_list_nb.append(pre2post_list[pre_id])
pre2post_list = pre2post_list_nb
return pre2post_list

@@ -132,20 +140,20 @@ def post2pre(i, j, num_post=None):
"""

if len(i) != len(j):
raise ModelUseError('The length of "i" and "j" must be the same.')
raise errors.ModelUseError('The length of "i" and "j" must be the same.')
if num_post is None:
print('WARNING: "num_post" is not provided, the result may not be accurate.')
num_post = np.max(j)
num_post = j.max()

post2pre_list = [[] for _ in range(num_post)]
for pre_id, post_id in zip(i, j):
post2pre_list[post_id].append(pre_id)
post2pre_list = [np.array(l) for l in post2pre_list]
post2pre_list = [backend.as_tensor(l) for l in post2pre_list]

if profile.is_jit():
if _numba_backend():
post2pre_list_nb = nb.typed.List()
for post_id in range(num_post):
post2pre_list_nb.append(np.int64(post2pre_list[post_id]))
post2pre_list_nb.append(post2pre_list[post_id])
post2pre_list = post2pre_list_nb
return post2pre_list

@@ -167,17 +175,17 @@ def pre2syn(i, num_pre=None):
"""
if num_pre is None:
print('WARNING: "num_pre" is not provided, the result may not be accurate.')
num_pre = np.max(i)
num_pre = i.max()

pre2syn_list = [[] for _ in range(num_pre)]
for syn_id, pre_id in enumerate(i):
pre2syn_list[pre_id].append(syn_id)
pre2syn_list = [np.array(l) for l in pre2syn_list]
pre2syn_list = [backend.as_tensor(l) for l in pre2syn_list]

if profile.is_jit():
if _numba_backend():
pre2syn_list_nb = nb.typed.List()
for pre_ids in pre2syn_list:
pre2syn_list_nb.append(np.int64(pre_ids))
pre2syn_list_nb.append(pre_ids)
pre2syn_list = pre2syn_list_nb

return pre2syn_list
@@ -200,17 +208,17 @@ def post2syn(j, num_post=None):
"""
if num_post is None:
print('WARNING: "num_post" is not provided, the result may not be accurate.')
num_post = np.max(j)
num_post = j.max()

post2syn_list = [[] for _ in range(num_post)]
for syn_id, post_id in enumerate(j):
post2syn_list[post_id].append(syn_id)
post2syn_list = [np.array(l) for l in post2syn_list]
post2syn_list = [backend.as_tensor(l) for l in post2syn_list]

if profile.is_jit():
if _numba_backend():
post2syn_list_nb = nb.typed.List()
for pre_ids in post2syn_list:
post2syn_list_nb.append(np.int64(pre_ids))
post2syn_list_nb.append(pre_ids)
post2syn_list = post2syn_list_nb

return post2syn_list
@@ -235,16 +243,21 @@ def pre_slice_syn(i, j, num_pre=None):
"""
# check
if len(i) != len(j):
raise ModelUseError('The length of "i" and "j" must be the same.')
raise errors.ModelUseError('The length of "i" and "j" must be the same.')
if num_pre is None:
print('WARNING: "num_pre" is not provided, the result may not be accurate.')
num_pre = np.max(i)
num_pre = i.max()

# pre2post connection
pre2post_list = [[] for _ in range(num_pre)]
for pre_id, post_id in zip(i, j):
pre2post_list[pre_id].append(post_id)
post_ids = np.asarray(np.concatenate(pre2post_list), dtype=np.int_)
pre_ids, post_ids = [], []
for pre_i, posts in enumerate(pre2post_list):
post_ids.extend(posts)
pre_ids.extend([pre_i] * len(posts))
post_ids = backend.as_tensor(post_ids)
pre_ids = backend.as_tensor(pre_ids)

# pre2post slicing
slicing = []
@@ -253,10 +266,7 @@ def pre_slice_syn(i, j, num_pre=None):
end = start + len(posts)
slicing.append([start, end])
start = end
slicing = np.asarray(slicing, dtype=np.int_)

# pre_ids
pre_ids = np.repeat(np.arange(num_pre), slicing[:, 1] - slicing[:, 0])
slicing = backend.as_tensor(slicing)

return pre_ids, post_ids, slicing
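
# Example (illustrative): for i = [0, 0, 1], j = [1, 2, 2] and
# num_pre = 2, synapses are laid out contiguously per pre-synaptic
# neuron and ``slicing`` stores each neuron's [start, end) range:
#
#     pre_ids  -> [0, 0, 1]
#     post_ids -> [1, 2, 2]
#     slicing  -> [[0, 2], [2, 3]]  # neuron 0 owns synapses 0:2, neuron 1 owns 2:3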

@@ -279,16 +289,21 @@ def post_slice_syn(i, j, num_post=None):
The conn list of post2syn.
"""
if len(i) != len(j):
raise ModelUseError('The length of "i" and "j" must be the same.')
raise errors.ModelUseError('The length of "i" and "j" must be the same.')
if num_post is None:
print('WARNING: "num_post" is not provided, the result may not be accurate.')
num_post = np.max(j)
num_post = j.max()

# post2pre connection
post2pre_list = [[] for _ in range(num_post)]
for pre_id, post_id in zip(i, j):
post2pre_list[post_id].append(pre_id)
pre_ids = np.asarray(np.concatenate(post2pre_list), dtype=np.int_)
pre_ids, post_ids = [], []
for _post_id, _pre_ids in enumerate(post2pre_list):
pre_ids.extend(_pre_ids)
post_ids.extend([_post_id] * len(_pre_ids))
post_ids = backend.as_tensor(post_ids)
pre_ids = backend.as_tensor(pre_ids)

# post2pre slicing
slicing = []
@@ -297,15 +312,23 @@ def post_slice_syn(i, j, num_post=None):
end = start + len(pres)
slicing.append([start, end])
start = end
slicing = np.asarray(slicing, dtype=np.int_)

# post_ids
post_ids = np.repeat(np.arange(num_post), slicing[:, 1] - slicing[:, 0])
slicing = backend.as_tensor(slicing)

return pre_ids, post_ids, slicing


class Connector(object):
SUPPORTED_SYN_STRUCTURE = ['pre_ids', 'post_ids', 'conn_mat',
'pre2post', 'post2pre',
'pre2syn', 'post2syn',
'pre_slice_syn', 'post_slice_syn']


class AbstractConnector(abc.ABC):
def __call__(self, *args, **kwargs):
pass


class Connector(AbstractConnector):
"""Abstract connector class."""

def __init__(self):
@@ -313,6 +336,7 @@ class Connector(object):
# useful for the construction of pre2post/pre2syn/etc.
self.num_pre = None
self.num_post = None

# synaptic structures
self.pre_ids = None
self.post_ids = None
@@ -323,47 +347,33 @@ class Connector(object):
self.post2syn = None
self.pre_slice_syn = None
self.post_slice_syn = None

# synaptic weights
self.weights = None
# the required synaptic structures
self.requires = ()

def set_size(self, num_pre, num_post):
try:
assert isinstance(num_pre, int)
assert 0 < num_pre
except AssertionError:
raise ModelUseError('"num_pre" must be integrator bigger than 0.')
try:
assert isinstance(num_post, int)
assert 0 < num_post
except AssertionError:
raise ModelUseError('"num_post" must be integrator bigger than 0.')
self.num_pre = num_pre
self.num_post = num_post

def set_requires(self, syn_requires):

def requires(self, *syn_requires):
# get synaptic requires
requires = set()
for n in syn_requires:
if n in ['pre_ids', 'post_ids', 'conn_mat',
'pre2post', 'post2pre',
'pre2syn', 'post2syn',
'pre_slice_syn', 'post_slice_syn']:
if n in SUPPORTED_SYN_STRUCTURE:
requires.add(n)
self.requires = list(requires)
else:
raise ValueError(f'Unknown synapse structure {n}. We only support '
f'{SUPPORTED_SYN_STRUCTURE}.')
requires = list(requires)

# synaptic structure to handle
needs = []
if 'pre_slice_syn' in self.requires and 'post_slice_syn' in self.requires:
raise ModelUseError('Cannot use "pre_slice_syn" and "post_slice_syn" simultaneously. \n'
'We recommend you use "pre_slice_syn + post2syn" '
'or "post_slice_syn + pre2syn".')
elif 'pre_slice_syn' in self.requires:
if 'pre_slice_syn' in requires and 'post_slice_syn' in requires:
raise errors.ModelUseError('Cannot use "pre_slice_syn" and "post_slice_syn" '
'simultaneously. \n'
'We recommend you use "pre_slice_syn + '
'post2syn" or "post_slice_syn + pre2syn".')
elif 'pre_slice_syn' in requires:
needs.append('pre_slice_syn')
elif 'post_slice_syn' in self.requires:
elif 'post_slice_syn' in requires:
needs.append('post_slice_syn')
for n in self.requires:
for n in requires:
if n in ['pre_slice_syn', 'post_slice_syn', 'pre_ids', 'post_ids']:
continue
needs.append(n)
@@ -372,8 +382,11 @@ class Connector(object):
for n in needs:
getattr(self, f'make_{n}')()

def __call__(self, pre_indices, post_indices):
raise NotImplementedError
# returns
if len(requires) == 1:
return getattr(self, requires[0])
else:
return tuple([getattr(self, r) for r in requires])

def make_conn_mat(self):
if self.conn_mat is None:


+ 324
- 334
brainpy/connectivity/methods.py View File

@@ -1,55 +1,233 @@
# -*- coding: utf-8 -*-

import numba as nb
import numpy as np

from . import base
from .. import errors
from brainpy import backend
from brainpy import errors
from .base import Connector

try:
import numba as nb
except ModuleNotFoundError:
nb = None

__all__ = [
'One2One', 'one2one',
'All2All', 'all2all',
'GridFour', 'grid_four',
'GridEight', 'grid_eight',
'GridN',
'FixedPostNum',
'FixedPreNum',
'FixedProb',
'GaussianProb',
'GaussianWeight',
'DOG',
'SmallWorld',
'ScaleFree'
]


def _size2len(size):
if isinstance(size, int):
return size
elif isinstance(size, (tuple, list)):
a = 1
for b in size:
a *= b
return a
else:
raise ValueError(f'Can not calculate the length of {size}.')

if nb is not None:
if hasattr(nb.core, 'dispatcher'):
from numba.core.dispatcher import Dispatcher
else:
from numba.core import Dispatcher

def _grid_four(height, width, row, include_self):
conn_i = []
conn_j = []

for col in range(width):
i_index = (row * width) + col
if 0 <= row - 1 < height:
j_index = ((row - 1) * width) + col
conn_i.append(i_index)
conn_j.append(j_index)
if 0 <= row + 1 < height:
j_index = ((row + 1) * width) + col
conn_i.append(i_index)
conn_j.append(j_index)
if 0 <= col - 1 < width:
j_index = (row * width) + col - 1
conn_i.append(i_index)
conn_j.append(j_index)
if 0 <= col + 1 < width:
j_index = (row * width) + col + 1
conn_i.append(i_index)
conn_j.append(j_index)
if include_self:
conn_i.append(i_index)
conn_j.append(i_index)
return conn_i, conn_j


def _grid_n(height, width, row, n, include_self):
conn_i = []
conn_j = []
for col in range(width):
i_index = (row * width) + col
for row_diff in range(-n, n + 1):
for col_diff in range(-n, n + 1):
if (not include_self) and (row_diff == col_diff == 0):
continue
if 0 <= row + row_diff < height and 0 <= col + col_diff < width:
j_index = ((row + row_diff) * width) + col + col_diff
conn_i.append(i_index)
conn_j.append(j_index)
return conn_i, conn_j


def _gaussian_weight(pre_i, pre_width, pre_height,
num_post, post_width, post_height,
w_max, w_min, sigma, normalize, include_self):
conn_i = []
conn_j = []
conn_w = []

# get normalized coordination
pre_coords = (pre_i // pre_width, pre_i % pre_width)
if normalize:
pre_coords = (pre_coords[0] / (pre_height - 1) if pre_height > 1 else 1.,
pre_coords[1] / (pre_width - 1) if pre_width > 1 else 1.)

for post_i in range(num_post):
if (pre_i == post_i) and (not include_self):
continue

# get normalized coordination
post_coords = (post_i // post_width, post_i % post_width)
if normalize:
post_coords = (post_coords[0] / (post_height - 1) if post_height > 1 else 1.,
post_coords[1] / (post_width - 1) if post_width > 1 else 1.)

# Compute Euclidean distance between two coordinates
distance = (pre_coords[0] - post_coords[0]) ** 2
distance += (pre_coords[1] - post_coords[1]) ** 2
# get weight and conn
value = w_max * np.exp(-distance / (2.0 * sigma ** 2))
if value > w_min:
conn_i.append(pre_i)
conn_j.append(post_i)
conn_w.append(value)
return conn_i, conn_j, conn_w

__all__ = ['One2One', 'one2one',
'All2All', 'all2all',
'GridFour', 'grid_four',
'GridEight', 'grid_eight',
'GridN',
'FixedPostNum', 'FixedPreNum', 'FixedProb',
'GaussianProb', 'GaussianWeight', 'DOG',
'SmallWorld', 'ScaleFree']

def _gaussian_prob(pre_i, pre_width, pre_height,
num_post, post_width, post_height,
p_min, sigma, normalize, include_self):
conn_i = []
conn_j = []
conn_p = []

# get normalized coordination
pre_coords = (pre_i // pre_width, pre_i % pre_width)
if normalize:
pre_coords = (pre_coords[0] / (pre_height - 1) if pre_height > 1 else 1.,
pre_coords[1] / (pre_width - 1) if pre_width > 1 else 1.)

for post_i in range(num_post):
if (pre_i == post_i) and (not include_self):
continue

class One2One(base.Connector):
# get normalized coordination
post_coords = (post_i // post_width, post_i % post_width)
if normalize:
post_coords = (post_coords[0] / (post_height - 1) if post_height > 1 else 1.,
post_coords[1] / (post_width - 1) if post_width > 1 else 1.)

# Compute Euclidean distance between two coordinates
distance = (pre_coords[0] - post_coords[0]) ** 2
distance += (pre_coords[1] - post_coords[1]) ** 2
# get weight and conn
value = np.exp(-distance / (2.0 * sigma ** 2))
if value > p_min:
conn_i.append(pre_i)
conn_j.append(post_i)
conn_p.append(value)
return conn_i, conn_j, conn_p


def _dog(pre_i, pre_width, pre_height,
num_post, post_width, post_height,
w_max_p, w_max_n, w_min, sigma_p, sigma_n,
normalize, include_self):
conn_i = []
conn_j = []
conn_w = []

# get normalized coordination
pre_coords = (pre_i // pre_width, pre_i % pre_width)
if normalize:
pre_coords = (pre_coords[0] / (pre_height - 1) if pre_height > 1 else 1.,
pre_coords[1] / (pre_width - 1) if pre_width > 1 else 1.)

for post_i in range(num_post):
if (pre_i == post_i) and (not include_self):
continue

# get normalized coordination
post_coords = (post_i // post_width, post_i % post_width)
if normalize:
post_coords = (post_coords[0] / (post_height - 1) if post_height > 1 else 1.,
post_coords[1] / (post_width - 1) if post_width > 1 else 1.)

# Compute Euclidean distance between two coordinates
distance = (pre_coords[0] - post_coords[0]) ** 2
distance += (pre_coords[1] - post_coords[1]) ** 2
# get weight and conn
value = w_max_p * np.exp(-distance / (2.0 * sigma_p ** 2)) - \
w_max_n * np.exp(-distance / (2.0 * sigma_n ** 2))
if np.abs(value) > w_min:
conn_i.append(pre_i)
conn_j.append(post_i)
conn_w.append(value)
return conn_i, conn_j, conn_w


if nb is not None:
_grid_four = nb.njit(_grid_four)
_grid_n = nb.njit(_grid_n)
_gaussian_weight = nb.njit(_gaussian_weight)
_gaussian_prob = nb.njit(_gaussian_prob)
_dog = nb.njit(_dog)
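
# Usage sketch for the Gaussian-weight helper above (illustrative;
# the parameter values are made up): weights from the center cell of a
# 3x3 pre grid to every cell of a 3x3 post grid, with normalized
# coordinates. Each kept weight is w_max * exp(-d2 / (2 * sigma**2)),
# where d2 is the squared distance, and values <= w_min are dropped.
#
#     i, j, w = _gaussian_weight(pre_i=4, pre_width=3, pre_height=3,
#                                num_post=9, post_width=3, post_height=3,
#                                w_max=1., w_min=0.01, sigma=0.5,
#                                normalize=True, include_self=False)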


class One2One(Connector):
"""
Connect two neuron groups one by one. This means
The two neuron groups should have the same size.
"""

def __init__(self):
super(One2One, self).__init__()

def __call__(self, pre_indices, post_indices):
pre_indices = np.asarray(pre_indices)
post_indices = np.asarray(post_indices)
self.pre_ids = np.ascontiguousarray(pre_indices.flatten(), dtype=np.int_)
self.post_ids = np.ascontiguousarray(post_indices.flatten(), dtype=np.int_)
def __call__(self, pre_size, post_size):
try:
assert np.size(self.pre_ids) == np.size(self.post_ids)
assert pre_size == post_size
except AssertionError:
raise errors.ModelUseError(f'One2One connection must be defined in two groups with the same size, '
f'but we got {np.size(self.pre_ids)} != {np.size(self.post_ids)}.')
if self.num_pre is None:
self.num_pre = pre_indices.max()
if self.num_post is None:
self.num_post = post_indices.max()
f'but we got {pre_size} != {post_size}.')

length = _size2len(pre_size)
self.num_pre = length
self.num_post = length

self.pre_ids = backend.arange(length)
self.post_ids = backend.arange(length)
return self


one2one = One2One()
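
# Usage sketch of the new call convention (illustrative): connectors
# now take the *sizes* of the two groups rather than index arrays.
#
#     conn = One2One()(pre_size=10, post_size=10)
#     conn.pre_ids   # -> [0, 1, ..., 9]
#     conn.post_ids  # -> [0, 1, ..., 9]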


class All2All(base.Connector):
class All2All(Connector):
"""Connect each neuron in first group to all neurons in the
post-synaptic neuron groups. It means this kind of conn
will create (num_pre x num_post) synapses.
@@ -59,73 +237,49 @@ class All2All(base.Connector):
self.include_self = include_self
super(All2All, self).__init__()

def __call__(self, pre_indices, post_indices):
pre_indices = pre_indices.flatten()
post_indices = post_indices.flatten()
num_pre, num_post = len(pre_indices), len(post_indices)
mat = np.ones((num_pre, num_post))
def __call__(self, pre_size, post_size):
pre_len = _size2len(pre_size)
post_len = _size2len(post_size)
self.num_pre = pre_len
self.num_post = post_len

mat = np.ones((pre_len, post_len))
if not self.include_self:
for i in range(min([num_post, num_pre])):
mat[i, i] = 0
eye = np.arange(min([pre_len, post_len]))
mat[eye, eye] = 0.
pre_ids, post_ids = np.where(mat > 0)
self.pre_ids = np.ascontiguousarray(pre_ids, dtype=np.int_)
self.post_ids = np.ascontiguousarray(post_ids, dtype=np.int_)
if self.num_pre is None:
self.num_pre = pre_indices.max()
if self.num_post is None:
self.num_post = post_indices.max()
self.pre_ids = backend.as_tensor(np.ascontiguousarray(pre_ids))
self.post_ids = backend.as_tensor(np.ascontiguousarray(post_ids))
self.conn_mat = backend.as_tensor(mat)
return self


all2all = All2All(include_self=True)

@nb.njit
def _grid_four(height, width, row, include_self):
conn_i = []
conn_j = []

for col in range(width):
i_index = (row * width) + col
if 0 <= row - 1 < height:
j_index = ((row - 1) * width) + col
conn_i.append(i_index)
conn_j.append(j_index)
if 0 <= row + 1 < height:
j_index = ((row + 1) * width) + col
conn_i.append(i_index)
conn_j.append(j_index)
if 0 <= col - 1 < width:
j_index = (row * width) + col - 1
conn_i.append(i_index)
conn_j.append(j_index)
if 0 <= col + 1 < width:
j_index = (row * width) + col + 1
conn_i.append(i_index)
conn_j.append(j_index)
if include_self:
conn_i.append(i_index)
conn_j.append(i_index)
return conn_i, conn_j


class GridFour(base.Connector):
class GridFour(Connector):
"""The nearest four neighbors conn method."""

def __init__(self, include_self=False):
super(GridFour, self).__init__()
self.include_self = include_self

def __call__(self, pre_indices, post_indices=None):
if post_indices is not None:
def __call__(self, pre_size, post_size=None):
self.num_pre = _size2len(pre_size)
if post_size is not None:
try:
assert np.shape(pre_indices) == np.shape(post_indices)
assert pre_size == post_size
except AssertionError:
raise errors.ModelUseError(f'The shape of pre-synaptic group should be the same with the post group. '
f'But we got {np.shape(pre_indices)} != {np.shape(post_indices)}.')
raise errors.ModelUseError(f'The shape of the pre-synaptic group should be the same '
f'as the post group. But we got {pre_size} != {post_size}.')
self.num_post = _size2len(post_size)
else:
self.num_post = self.num_pre

if len(pre_indices.shape) == 1:
height, width = pre_indices.shape[0], 1
elif len(pre_indices.shape) == 2:
height, width = pre_indices.shape
if len(pre_size) == 1:
height, width = pre_size[0], 1
elif len(pre_size) == 2:
height, width = pre_size
else:
raise errors.ModelUseError('Currently only support two-dimensional geometry.')
conn_i = []
@@ -134,43 +288,15 @@ class GridFour(base.Connector):
a = _grid_four(height, width, row, include_self=self.include_self)
conn_i.extend(a[0])
conn_j.extend(a[1])
conn_i = np.asarray(conn_i)
conn_j = np.asarray(conn_j)

pre_indices = pre_indices.flatten()
self.pre_ids = pre_indices[conn_i]
if self.num_pre is None:
self.num_pre = pre_indices.max()
if post_indices is None:
self.post_ids = pre_indices[conn_j]
else:
post_indices = post_indices.flatten()
self.post_ids = post_indices[conn_j]
if self.num_post is None:
self.num_post = post_indices.max()
self.pre_ids = backend.as_tensor(conn_i)
self.post_ids = backend.as_tensor(conn_j)
return self


grid_four = GridFour()


@nb.njit
def _grid_n(height, width, row, n, include_self):
conn_i = []
conn_j = []
for col in range(width):
i_index = (row * width) + col
for row_diff in range(-n, n + 1):
for col_diff in range(-n, n + 1):
if (not include_self) and (row_diff == col_diff == 0):
continue
if 0 <= row + row_diff < height and 0 <= col + col_diff < width:
j_index = ((row + row_diff) * width) + col + col_diff
conn_i.append(i_index)
conn_j.append(j_index)
return conn_i, conn_j


class GridN(base.Connector):
class GridN(Connector):
"""The nearest (2*N+1) * (2*N+1) neighbors conn method.

Parameters
@@ -196,18 +322,23 @@ class GridN(base.Connector):
self.n = n
self.include_self = include_self

def __call__(self, pre_indices, post_indices=None):
if post_indices is not None:
def __call__(self, pre_size, post_size=None):
self.num_pre = _size2len(pre_size)
if post_size is not None:
try:
assert np.shape(pre_indices) == np.shape(post_indices)
assert pre_size == post_size
except AssertionError:
raise errors.ModelUseError(f'The shape of pre-synaptic group should be the same with the post group. '
f'But we got {np.shape(pre_indices)} != {np.shape(post_indices)}.')
raise errors.ModelUseError(
f'The shape of the pre-synaptic group should be the same as the post group. '
f'But we got {pre_size} != {post_size}.')
self.num_post = _size2len(post_size)
else:
self.num_post = self.num_pre

if len(pre_indices.shape) == 1:
height, width = pre_indices.shape[0], 1
elif len(pre_indices.shape) == 2:
height, width = pre_indices.shape
if len(pre_size) == 1:
height, width = pre_size[0], 1
elif len(pre_size) == 2:
height, width = pre_size
else:
raise errors.ModelUseError('Currently only support two-dimensional geometry.')

@@ -218,20 +349,9 @@ class GridN(base.Connector):
n=self.n, include_self=self.include_self)
conn_i.extend(res[0])
conn_j.extend(res[1])
conn_i = np.asarray(conn_i, dtype=np.int_)
conn_j = np.asarray(conn_j, dtype=np.int_)

pre_indices = pre_indices.flatten()
if self.num_pre is None:
self.num_pre = pre_indices.max()
self.pre_ids = pre_indices[conn_i]
if post_indices is None:
self.post_ids = pre_indices[conn_j]
else:
post_indices = post_indices.flatten()
self.post_ids = post_indices[conn_j]
if self.num_post is None:
self.num_post = post_indices.max()
self.pre_ids = backend.as_tensor(conn_i)
self.post_ids = backend.as_tensor(conn_j)
return self


class GridEight(GridN):
@@ -244,7 +364,7 @@ class GridEight(GridN):
grid_eight = GridEight()


class FixedProb(base.Connector):
class FixedProb(Connector):
"""Connect the post-synaptic neurons with fixed probability.

Parameters
@@ -263,27 +383,23 @@ class FixedProb(base.Connector):
self.include_self = include_self
self.seed = seed

def __call__(self, pre_indices, post_indices):
pre_indices = pre_indices.flatten()
post_indices = post_indices.flatten()
def __call__(self, pre_size, post_size):
num_pre, num_post = _size2len(pre_size), _size2len(post_size)
self.num_pre, self.num_post = num_pre, num_post

num_pre, num_post = len(pre_indices), len(post_indices)
prob_mat = np.random.random(size=(num_pre, num_post))
if not self.include_self:
diag_index = np.arange(min([num_pre, num_post]))
prob_mat[diag_index, diag_index] = 1.
conn_mat = prob_mat < self.prob
conn_mat = np.array(prob_mat < self.prob, dtype=np.int_)
pre_ids, post_ids = np.where(conn_mat)
self.conn_mat = np.float_(conn_mat)
self.pre_ids = pre_indices[pre_ids]
self.post_ids = post_indices[post_ids]
if self.num_pre is None:
self.num_pre = pre_indices.max()
if self.num_post is None:
self.num_post = post_indices.max()
self.conn_mat = backend.as_tensor(conn_mat)
self.pre_ids = backend.as_tensor(np.ascontiguousarray(pre_ids))
self.post_ids = backend.as_tensor(np.ascontiguousarray(post_ids))
return self


class FixedPreNum(base.Connector):
class FixedPreNum(Connector):
"""Connect the pre-synaptic neurons with fixed number for each
post-synaptic neuron.

@@ -310,10 +426,9 @@ class FixedPreNum(base.Connector):
self.include_self = include_self
self.seed = seed

def __call__(self, pre_indices, post_indices):
pre_indices = pre_indices.flatten()
post_indices = post_indices.flatten()
num_pre, num_post = len(pre_indices), len(post_indices)
def __call__(self, pre_size, post_size):
num_pre, num_post = _size2len(pre_size), _size2len(post_size)
self.num_pre, self.num_post = num_pre, num_post
num = self.num if isinstance(self.num, int) else int(self.num * num_pre)
assert num <= num_pre, f'"num" must be less than "num_pre", but got {num} > {num_pre}'
prob_mat = np.random.random(size=(num_pre, num_post))
@@ -321,15 +436,14 @@ class FixedPreNum(base.Connector):
diag_index = np.arange(min([num_pre, num_post]))
prob_mat[diag_index, diag_index] = 1.1
arg_sort = np.argsort(prob_mat, axis=0)[:num]
self.pre_ids = np.asarray(np.concatenate(arg_sort), dtype=np.int64)
self.post_ids = np.asarray(np.repeat(np.arange(num_post), num_pre), dtype=np.int64)
if self.num_pre is None:
self.num_pre = pre_indices.max()
if self.num_post is None:
self.num_post = post_indices.max()
pre_ids = np.asarray(np.concatenate(arg_sort), dtype=np.int_)
post_ids = np.asarray(np.tile(np.arange(num_post), num), dtype=np.int_)
self.pre_ids = backend.as_tensor(pre_ids)
self.post_ids = backend.as_tensor(post_ids)
return self


class FixedPostNum(base.Connector):
class FixedPostNum(Connector):
"""Connect the post-synaptic neurons with fixed number for each
pre-synaptic neuron.

@@ -356,11 +470,11 @@ class FixedPostNum(base.Connector):
self.seed = seed
super(FixedPostNum, self).__init__()

def __call__(self, pre_indices, post_indices):
pre_indices = pre_indices.flatten()
post_indices = post_indices.flatten()
num_pre = len(pre_indices)
num_post = len(post_indices)
def __call__(self, pre_size, post_size):
num_pre = _size2len(pre_size)
num_post = _size2len(post_size)
self.num_pre = num_pre
self.num_post = num_post
num = self.num if isinstance(self.num, int) else int(self.num * num_post)
assert num <= num_post, f'"num" must be less than "num_post", but got {num} > {num_post}'
prob_mat = np.random.random(size=(num_pre, num_post))
@@ -368,51 +482,14 @@ class FixedPostNum(base.Connector):
diag_index = np.arange(min([num_pre, num_post]))
prob_mat[diag_index, diag_index] = 1.1
arg_sort = np.argsort(prob_mat, axis=1)[:, :num]
self.post_ids = np.asarray(np.concatenate(arg_sort), dtype=np.int64)
self.pre_ids = np.asarray(np.repeat(np.arange(num_pre), num_post), dtype=np.int64)
if self.num_pre is None:
self.num_pre = pre_indices.max()
if self.num_post is None:
self.num_post = post_indices.max()


@nb.njit
def _gaussian_weight(pre_i, pre_width, pre_height,
num_post, post_width, post_height,
w_max, w_min, sigma, normalize, include_self):
conn_i = []
conn_j = []
conn_w = []

# get normalized coordination
pre_coords = (pre_i // pre_width, pre_i % pre_width)
if normalize:
pre_coords = (pre_coords[0] / (pre_height - 1) if pre_height > 1 else 1.,
pre_coords[1] / (pre_width - 1) if pre_width > 1 else 1.)

for post_i in range(num_post):
if (pre_i == post_i) and (not include_self):
continue

# get normalized coordination
post_coords = (post_i // post_width, post_i % post_width)
if normalize:
post_coords = (post_coords[0] / (post_height - 1) if post_height > 1 else 1.,
post_coords[1] / (post_width - 1) if post_width > 1 else 1.)

# Compute Euclidean distance between two coordinates
distance = (pre_coords[0] - post_coords[0]) ** 2
distance += (pre_coords[1] - post_coords[1]) ** 2
# get weight and conn
value = w_max * np.exp(-distance / (2.0 * sigma ** 2))
if value > w_min:
conn_i.append(pre_i)
conn_j.append(post_i)
conn_w.append(value)
return conn_i, conn_j, conn_w
post_ids = np.asarray(np.concatenate(arg_sort), dtype=np.int64)
pre_ids = np.asarray(np.repeat(np.arange(num_pre), num), dtype=np.int64)
self.pre_ids = backend.as_tensor(pre_ids)
self.post_ids = backend.as_tensor(post_ids)
return self


class GaussianWeight(base.Connector):
class GaussianWeight(Connector):
"""Builds a Gaussian conn pattern between the two populations, where
the weights decay with gaussian function.

@@ -451,13 +528,15 @@ class GaussianWeight(base.Connector):
self.normalize = normalize
self.include_self = include_self

def __call__(self, pre_indices, post_indices):
num_pre = np.size(pre_indices)
num_post = np.size(post_indices)
assert np.ndim(pre_indices) == 2
assert np.ndim(post_indices) == 2
pre_height, pre_width = pre_indices.shape
post_height, post_width = post_indices.shape
def __call__(self, pre_size, post_size):
num_pre = _size2len(pre_size)
num_post = _size2len(post_size)
self.num_pre = num_pre
self.num_post = num_post
assert len(pre_size) == 2
assert len(post_size) == 2
pre_height, pre_width = pre_size
post_height, post_width = post_size

# get the connections and weights
i, j, w = [], [], []
@@ -480,54 +559,13 @@ class GaussianWeight(base.Connector):
pre_ids = np.asarray(i, dtype=np.int_)
post_ids = np.asarray(j, dtype=np.int_)
w = np.asarray(w, dtype=np.float_)
pre_indices = pre_indices.flatten()
post_indices = post_indices.flatten()
self.pre_ids = pre_indices[pre_ids]
self.post_ids = post_indices[post_ids]
self.weights = w
if self.num_pre is None:
self.num_pre = pre_indices.max()
if self.num_post is None:
self.num_post = post_indices.max()


@nb.njit
def _gaussian_prob(pre_i, pre_width, pre_height,
num_post, post_width, post_height,
p_min, sigma, normalize, include_self):
conn_i = []
conn_j = []
conn_p = []

# get normalized coordination
pre_coords = (pre_i // pre_width, pre_i % pre_width)
if normalize:
pre_coords = (pre_coords[0] / (pre_height - 1) if pre_height > 1 else 1.,
pre_coords[1] / (pre_width - 1) if pre_width > 1 else 1.)
self.pre_ids = backend.as_tensor(pre_ids)
self.post_ids = backend.as_tensor(post_ids)
self.weights = backend.as_tensor(w)
return self

for post_i in range(num_post):
if (pre_i == post_i) and (not include_self):
continue

# get normalized coordination
post_coords = (post_i // post_width, post_i % post_width)
if normalize:
post_coords = (post_coords[0] / (post_height - 1) if post_height > 1 else 1.,
post_coords[1] / (post_width - 1) if post_width > 1 else 1.)

# Compute Euclidean distance between two coordinates
distance = (pre_coords[0] - post_coords[0]) ** 2
distance += (pre_coords[1] - post_coords[1]) ** 2
# get weight and conn
value = np.exp(-distance / (2.0 * sigma ** 2))
if value > p_min:
conn_i.append(pre_i)
conn_j.append(post_i)
conn_p.append(value)
return conn_i, conn_j, conn_p


class GaussianProb(base.Connector):
class GaussianProb(Connector):
"""Builds a Gaussian conn pattern between the two populations, where
the conn probability decay according to the gaussian function.

@@ -559,13 +597,13 @@ class GaussianProb(base.Connector):
self.normalize = normalize
self.include_self = include_self

def __call__(self, pre_indices, post_indices):
num_pre = np.size(pre_indices)
num_post = np.size(post_indices)
assert np.ndim(pre_indices) == 2
assert np.ndim(post_indices) == 2
pre_height, pre_width = pre_indices.shape
post_height, post_width = post_indices.shape
def __call__(self, pre_size, post_size):
self.num_pre = num_pre = _size2len(pre_size)
self.num_post = num_post = _size2len(post_size)
assert len(pre_size) == 2
assert len(post_size) == 2
pre_height, pre_width = pre_size
post_height, post_width = post_size

# get the connections
i, j, p = [], [], [] # conn_i, conn_j, probabilities
@@ -587,55 +625,12 @@ class GaussianProb(base.Connector):
selected_idxs = np.where(np.random.random(len(p)) < p)[0]
i = np.asarray(i, dtype=np.int_)[selected_idxs]
j = np.asarray(j, dtype=np.int_)[selected_idxs]
pre_indices = pre_indices.flatten()
post_indices = post_indices.flatten()
self.pre_ids = pre_indices[i]
self.post_ids = post_indices[j]
if self.num_pre is None:
self.num_pre = pre_indices.max()
if self.num_post is None:
self.num_post = post_indices.max()
self.pre_ids = backend.as_tensor(i)
self.post_ids = backend.as_tensor(j)
return self


@nb.njit
def _dog(pre_i, pre_width, pre_height,
num_post, post_width, post_height,
w_max_p, w_max_n, w_min, sigma_p, sigma_n,
normalize, include_self):
conn_i = []
conn_j = []
conn_w = []

# get normalized coordination
pre_coords = (pre_i // pre_width, pre_i % pre_width)
if normalize:
pre_coords = (pre_coords[0] / (pre_height - 1) if pre_height > 1 else 1.,
pre_coords[1] / (pre_width - 1) if pre_width > 1 else 1.)

for post_i in range(num_post):
if (pre_i == post_i) and (not include_self):
continue

# get normalized coordination
post_coords = (post_i // post_width, post_i % post_width)
if normalize:
post_coords = (post_coords[0] / (post_height - 1) if post_height > 1 else 1.,
post_coords[1] / (post_width - 1) if post_width > 1 else 1.)

# Compute Euclidean distance between two coordinates
distance = (pre_coords[0] - post_coords[0]) ** 2
distance += (pre_coords[1] - post_coords[1]) ** 2
# get weight and conn
value = w_max_p * np.exp(-distance / (2.0 * sigma_p ** 2)) - \
w_max_n * np.exp(-distance / (2.0 * sigma_n ** 2))
if np.abs(value) > w_min:
conn_i.append(pre_i)
conn_j.append(post_i)
conn_w.append(value)
return conn_i, conn_j, conn_w


class DOG(base.Connector):
class DOG(Connector):
"""Builds a Difference-Of-Gaussian (dog) conn pattern between the two populations.

Mathematically,
@@ -671,13 +666,13 @@ class DOG(base.Connector):
self.normalize = normalize
self.include_self = include_self

def __call__(self, pre_indices, post_indices):
num_pre = np.size(pre_indices)
num_post = np.size(post_indices)
assert np.ndim(pre_indices) == 2
assert np.ndim(post_indices) == 2
pre_height, pre_width = pre_indices.shape
post_height, post_width = post_indices.shape
def __call__(self, pre_size, post_size):
self.num_pre = num_pre = _size2len(pre_size)
self.num_post = num_post = _size2len(post_size)
assert len(pre_size) == 2
assert len(post_size) == 2
pre_height, pre_width = pre_size
post_height, post_width = post_size

# get the connections and weights
i, j, w = [], [], [] # conn_i, conn_j, weights
@@ -703,22 +698,17 @@ class DOG(base.Connector):
i = np.asarray(i, dtype=np.int_)
j = np.asarray(j, dtype=np.int_)
w = np.asarray(w, dtype=np.float_)
pre_indices = pre_indices.flatten()
post_indices = post_indices.flatten()
self.pre_ids = pre_indices[i]
self.post_ids = post_indices[j]
self.weights = w
if self.num_pre is None:
self.num_pre = pre_indices.max()
if self.num_post is None:
self.num_post = post_indices.max()


class ScaleFree(base.Connector):
self.pre_ids = backend.as_tensor(i)
self.post_ids = backend.as_tensor(j)
self.weights = backend.as_tensor(w)
return self


class ScaleFree(Connector):
def __init__(self):
raise NotImplementedError


class SmallWorld(base.Connector):
class SmallWorld(Connector):
def __init__(self):
raise NotImplementedError

+ 0
- 8
brainpy/core/__init__.py View File

@@ -1,8 +0,0 @@
# -*- coding: utf-8 -*-

from .base import *
from .types import *
from .runner import *
from .neurons import *
from .synapses import *
from .network import *

+ 0
- 609
brainpy/core/base.py View File

@@ -1,609 +0,0 @@
# -*- coding: utf-8 -*-

import inspect
import re
import time
from copy import deepcopy

import numpy as np
from numba import cuda

from . import constants
from . import runner
from . import types
from . import utils
from .. import errors
from .. import profile
from .. import tools

__all__ = [
'ObjType',
'Ensemble',
'ParsUpdate',
]


class ObjType(object):
"""The base type of neuron and synapse.

Parameters
----------
name : str, optional
Model name.
"""

def __init__(self, ST, name, steps, requires=None, mode='vector', hand_overs=None, ):
self.mode = mode
self.name = name
if not isinstance(ST, types.ObjState):
raise errors.ModelDefError('"ST" must be an instance of ObjState.')
self.ST = ST

# requires
# ---------
if requires is None:
requires = dict()
if not isinstance(requires, dict):
raise errors.ModelDefError('"requires" only supports dict.')
self.requires = requires
for k, v in requires.items():
if isinstance(v, type):
raise errors.ModelDefError(f'In "requires", you must instantiate '
f'the type checker of "{k}". '
f'Like "{v.__name__}()".')
if not isinstance(v, types.TypeChecker):
raise errors.ModelDefError(f'In "requires", each value must be a '
f'{types.TypeChecker.__name__}, '
f'but got "{type(v)}" for "{k}".')

# steps
# ------
self.steps = []
self.step_names = []
self.step_scopes = dict()
self.step_args = set()
step_vars = set()
if callable(steps):
steps = [steps]
elif isinstance(steps, (list, tuple)):
steps = list(steps)
else:
raise errors.ModelDefError('"steps" must be a callable, or a '
'list/tuple of callable functions.')
for func in steps:
if not callable(func):
raise errors.ModelDefError('"steps" must be a list/tuple of callable functions.')

# function name
func_name = tools.get_func_name(func, replace=True)
self.step_names.append(func_name)

# function arg
for arg in inspect.getfullargspec(func).args:
if arg in constants.ARG_KEYWORDS:
continue
self.step_args.add(arg)

# function scope
scope = utils.get_func_scope(func, include_dispatcher=True)
for k, v in scope.items():
if k in self.step_scopes:
if v != self.step_scopes[k]:
raise errors.ModelDefError(
f'Found scope variable {k} with different values in '
f'{self.name}: {k} = {v} and {k} = {self.step_scopes[k]}.\n'
f'This may cause serious errors later. '
f'Please change it!')
self.step_scopes[k] = v

# function
self.steps.append(func)

# set attribute
setattr(self, func_name, func)

# get the STATE variables
step_vars.update(re.findall(r'ST\[[\'"](\w+)[\'"]\]', tools.get_main_code(func)))

self.step_args = list(self.step_args)

# variables
# ----------
self.variables = ST._vars
for var in step_vars:
if var not in self.variables:
raise errors.ModelDefError(f'Variable "{var}" is used in {self.name}, '
f'but not defined in "ST".')

# integrators
# -----------
self.integrators = []
for step in self.steps:
self.integrators.extend(utils.find_integrators(step))
self.integrators = list(set(self.integrators))

# delay keys
# ----------
self._delay_keys = []

# hand overs
# ---------------
if hand_overs is not None:
if not isinstance(hand_overs, dict):
raise errors.ModelUseError('"hand_overs" must be a dict.')
else:
hand_overs = dict()
self.hand_overs = hand_overs

def __str__(self):
return f'{self.name}'


class ParsUpdate(dict):
"""Class for parameter updating.

Structure of ``ParsUpdate``

- origins : original parameters
- num : number of the neurons
- updates : parameters to update
- heters : parameters to update, and they are heterogeneous
- model : the model which this ParsUpdate belongs to

"""

def __init__(self, all_pars, num, model):
assert isinstance(all_pars, dict)
assert isinstance(num, int)

super(ParsUpdate, self).__init__(origins=all_pars,
num=num,
heters=dict(),
updates=dict(),
model=model)

def __setitem__(self, key, value):
# check the existence of "key"
if key not in self.origins:
raise errors.ModelUseError(f'Parameter "{key}" may be not defined in '
f'"{self.model.name}" variable scope.\n'
f'Or, "{key}" is used to compute an '
f'intermediate variable, and is not '
f'directly used by the step functions.')

# check value size
val_size = np.size(value)
if val_size != 1:
if val_size != self.num:
raise errors.ModelUseError(
f'The size of parameter "{key}" is wrong, "{val_size}" != 1 '
f'and "{val_size}" != {self.num}.')
if np.size(self.origins[key]) != val_size: # maybe default value is a heterogeneous value
self.heters[key] = value

# update
if profile.run_on_cpu():
self.updates[key] = value
else:
if isinstance(value, (int, float)):
self.updates[key] = value
elif value.__class__.__name__ == 'DeviceNDArray':
self.updates[key] = value
elif isinstance(value, np.ndarray):
self.updates[key] = cuda.to_device(value)
else:
raise ValueError(f'GPU mode cannot support {type(value)}.')

def __getitem__(self, item):
if item in self.updates:
return self.updates[item]
elif item in self.origins:
return self.origins[item]
else:
return super(ParsUpdate, self).__getitem__(item)

def __dir__(self):
return list(self.all.keys())

def keys(self):
"""All parameters can be updated.

Returns
-------
keys : list
List of parameter names.
"""
return self.origins.keys()

def items(self):
"""All parameters, including keys and values.

Returns
-------
items : iterable
The iterable parameter items.
"""
return self.all.items()

def get(self, item):
"""Get the parameter value by its key.

Parameters
----------
item : str
Parameter name.

Returns
-------
value : any
Parameter value.
"""
return self.all.__getitem__(item)

@property
def origins(self):
return super(ParsUpdate, self).__getitem__('origins')

@property
def heters(self):
return super(ParsUpdate, self).__getitem__('heters')

@property
def updates(self):
return super(ParsUpdate, self).__getitem__('updates')

@property
def num(self):
return super(ParsUpdate, self).__getitem__('num')

@property
def model(self):
return super(ParsUpdate, self).__getitem__('model')

@property
def all(self):
origins = deepcopy(self.origins)
origins.update(self.updates)
return origins
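
# [Editor's sketch] The lookup precedence ParsUpdate implements, in isolation:
# user-supplied updates shadow the model's original scope values. The dicts
# below are illustrative stand-ins for the 'origins' and 'updates' fields.
def _resolve_par(origins, updates, key):
    return updates[key] if key in updates else origins[key]

# e.g. _resolve_par({'tau': 10.0}, {'tau': 12.0}, 'tau') -> 12.0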


class Ensemble(object):
"""Base Ensemble class.

Parameters
----------
name : str
Name of the (neurons/synapses) ensemble.
num : int
The number of the neurons/synapses.
model : ObjType
The (neuron/synapse) model.
monitors : list, tuple, None
Variables to monitor.
pars_update : dict, None
Parameters to update.
cls_type : str
Class type.
"""

def __init__(self, name, num, model, monitors, pars_update, cls_type, satisfies=None, ):
# class type
# -----------
if cls_type not in [constants.NEU_GROUP_TYPE, constants.SYN_CONN_TYPE]:
raise errors.ModelUseError(f'Only support "{constants.NEU_GROUP_TYPE}" '
f'and "{constants.SYN_CONN_TYPE}".')
self._cls_type = cls_type

# model
# -----
self.model = model

# name
# ----
self.name = name
if not self.name.isidentifier():
raise errors.ModelUseError(
f'"{self.name}" isn\'t a valid identifier according to Python '
f'language definition. Please choose another name.')

# num
# ---
self.num = num

# parameters
# ----------
self.pars = ParsUpdate(all_pars=model.step_scopes, num=num, model=model)
pars_update = dict() if pars_update is None else pars_update
if not isinstance(pars_update, dict):
raise errors.ModelUseError('"pars_update" must be a dict.')
for k, v in pars_update.items():
self.pars[k] = v

# monitors
# ---------
self.mon = tools.DictPlus()
self._mon_items = []
if monitors is not None:
if isinstance(monitors, (list, tuple)):
for var in monitors:
if isinstance(var, str):
self._mon_items.append((var, None))
self.mon[var] = np.empty((1, 1), dtype=np.float_)
elif isinstance(var, (tuple, list)):
self._mon_items.append((var[0], var[1]))
self.mon[var[0]] = np.empty((1, 1), dtype=np.float_)
else:
raise errors.ModelUseError(f'Unknown monitor item: {str(var)}')
elif isinstance(monitors, dict):
for k, v in monitors.items():
self._mon_items.append((k, v))
self.mon[k] = np.empty((1, 1), dtype=np.float_)
else:
raise errors.ModelUseError(f'Unknown monitors type: {type(monitors)}')

# runner
# -------
self.runner = runner.Runner(ensemble=self)

# hand overs
# ----------
# 1. attributes
# 2. functions
for attr_key, attr_val in model.hand_overs.items():
setattr(self, attr_key, attr_val)

# satisfies
# ---------
if satisfies is not None:
if not isinstance(satisfies, dict):
raise errors.ModelUseError('"satisfies" must be dict.')
for key, val in satisfies.items():
setattr(self, key, val)

def _is_state_attr(self, arg):
try:
attr = getattr(self, arg)
except AttributeError:
return False
if self._cls_type == constants.NEU_GROUP_TYPE:
return isinstance(attr, types.NeuState)
elif self._cls_type == constants.SYN_CONN_TYPE:
return isinstance(attr, types.SynState)
else:
raise ValueError(f'Unknown class type: "{self._cls_type}".')

def type_checking(self):
"""Check the data type needed for step function.
"""
# 1. check ST and its type
if not hasattr(self, 'ST'):
raise errors.ModelUseError(f'"{self.name}" doesn\'t have "ST" attribute.')
try:
self.model.ST.check(self.ST)
except errors.TypeMismatchError:
raise errors.ModelUseError(f'"{self.name}.ST" doesn\'t satisfy TypeChecker "{str(self.model.ST)}".')

# 2. check requires and its type
for key, type_checker in self.model.requires.items():
if not hasattr(self, key):
raise errors.ModelUseError(f'"{self.name}" doesn\'t have "{key}" attribute.')
try:
type_checker.check(getattr(self, key))
except errors.TypeMismatchError:
raise errors.ModelUseError(f'"{self.name}.{key}" doesn\'t satisfy TypeChecker "{str(type_checker)}".')

# 3. check data (function arguments) needed
for i, func in enumerate(self.model.steps):
for arg in inspect.getfullargspec(func).args:
if not (arg in constants.ARG_KEYWORDS + ['self']) and not hasattr(self, arg):
raise errors.ModelUseError(
f'Function "{self.model.step_names[i]}" in "{self.model.name}" '
f'requires "{arg}" as argument, but "{arg}" is not defined in "{self.name}".')

def reshape_mon(self, run_length):
for key, val in self.mon.items():
if key == 'ts':
continue
shape = val.shape
if run_length < shape[0]:
self.mon[key] = val[:run_length]
elif run_length > shape[0]:
append = np.zeros((run_length - shape[0],) + shape[1:])
self.mon[key] = np.vstack([val, append])
if profile.run_on_gpu():
for key, val in self.mon.items():
key_gpu = f'mon_{key}_cuda'
val_gpu = cuda.to_device(val)
setattr(self.runner, key_gpu, val_gpu)
self.runner.gpu_data[key_gpu] = val_gpu

def build(self, inputs=None, mon_length=0):
"""Build the object for running.

Parameters
----------
inputs : list, tuple
The object inputs.
mon_length : int
The monitor length.

Returns
-------
calls : list, tuple
The code lines to call step functions.
"""

# 1. prerequisite
# ---------------
if profile.run_on_gpu():
if self.model.mode != constants.SCALAR_MODE:
raise errors.ModelUseError(f'GPU mode only supports the scalar-based mode, '
f'but {self.model} is a {self.model.mode}-based model.')
self.type_checking()

# 2. Code results
# ---------------
code_results = dict()
# inputs
if inputs:
r = self.runner.get_codes_of_input(inputs)
code_results.update(r)
# monitors
if len(self._mon_items):
mon, r = self.runner.get_codes_of_monitor(self._mon_items, run_length=mon_length)
code_results.update(r)
self.mon.clear()
self.mon.update(mon)
# steps
r = self.runner.get_codes_of_steps()
code_results.update(r)

# 3. code calls
# -------------
calls = self.runner.merge_codes(code_results)
if self._cls_type == constants.SYN_CONN_TYPE:
if self.delay_len > 1:
calls.append(f'{self.name}.ST._update_delay_indices()')

return calls
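
# [Editor's note] A hypothetical example of the call strings build() returns,
# one per scheduled function (cf. the "call" entries assembled by the runner):
#   'NeuGroup0_runner.input_step(NeuGroup0.ST["_data"], NeuGroup0_runner.ST_input_inp)'
#   'NeuGroup0_runner.update(NeuGroup0.ST["_data"], _t)'
#   'NeuGroup0_runner.monitor_step(NeuGroup0.ST["_data"], _i, NeuGroup0.mon["V"])'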

def run(self, duration, inputs=(), report=False, report_percent=0.1):
"""The running function.

Parameters
----------
duration : float, int, tuple, list
The running duration.
inputs : list, tuple
The model inputs with the format of ``[(key, value, [operation])]``.
report : bool
Whether to report the running progress.
report_percent : float
The percent of progress to report.
"""

# times
# ------
if isinstance(duration, (int, float)):
start, end = 0., duration
elif isinstance(duration, (tuple, list)):
assert len(duration) == 2, 'Only support duration setting with the format of "(start, end)".'
start, end = duration
else:
raise ValueError(f'Unknown duration type: {type(duration)}')
times = np.asarray(np.arange(start, end, profile.get_dt()), dtype=np.float_)
run_length = times.shape[0]

# check inputs
# -------------
if not isinstance(inputs, (tuple, list)):
raise errors.ModelUseError('"inputs" must be a tuple/list.')
if len(inputs) and not isinstance(inputs[0], (list, tuple)):
if isinstance(inputs[0], str):
inputs = [inputs]
else:
raise errors.ModelUseError('Unknown input structure; only inputs with the '
'format of "(key, value, [operation])" are supported.')
for inp in inputs:
if not 2 <= len(inp) <= 3:
raise errors.ModelUseError('For each target, you must specify "(key, value, [operation])".')
if len(inp) == 3 and inp[2] not in constants.INPUT_OPERATIONS:
raise errors.ModelUseError(f'Input operation only supports '
f'"{list(constants.INPUT_OPERATIONS.keys())}", '
f'not "{inp[2]}".')

# format inputs
# -------------
formatted_inputs = []
for inp in inputs:
# key
if not isinstance(inp[0], str):
raise errors.ModelUseError('For each input, input[0] must be a string '
'to specify variable of the target.')
key = inp[0]
# value and data type
if isinstance(inp[1], (int, float)):
val = inp[1]
data_type = 'fix'
elif isinstance(inp[1], np.ndarray):
val = inp[1]
if val.shape[0] == run_length:
data_type = 'iter'
else:
data_type = 'fix'
else:
raise errors.ModelUseError('For each input, input[1] must be a '
'numerical value to specify input values.')
# operation
if len(inp) == 3:
ops = inp[2]
else:
ops = '+'
# input
format_inp = (key, val, ops, data_type)
formatted_inputs.append(format_inp)

# get step function
# -------------------
lines_of_call = self.build(inputs=formatted_inputs, mon_length=run_length)
code_lines = ['def step_func(_t, _i, _dt):']
code_lines.extend(lines_of_call)
code_scopes = {self.name: self, f"{self.name}_runner": self.runner}
if profile.run_on_gpu():
code_scopes['cuda'] = cuda
func_code = '\n '.join(code_lines)
exec(compile(func_code, '', 'exec'), code_scopes)
step_func = code_scopes['step_func']
if profile.show_format_code():
utils.show_code_str(func_code)
if profile.show_code_scope():
utils.show_code_scope(code_scopes, ['__builtins__', 'step_func'])

# run the model
# -------------
dt = profile.get_dt()
if report:
t0 = time.time()
step_func(_t=times[0], _i=0, _dt=dt)
print('Compilation used {:.4f} s.'.format(time.time() - t0))

print("Start running ...")
report_gap = max(int(run_length * report_percent), 1)
t0 = time.time()
for run_idx in range(1, run_length):
step_func(_t=times[run_idx], _i=run_idx, _dt=dt)
if (run_idx + 1) % report_gap == 0:
percent = (run_idx + 1) / run_length * 100
print('Run {:.1f}% used {:.3f} s.'.format(percent, time.time() - t0))
print('Simulation is done in {:.3f} s.'.format(time.time() - t0))
else:
for run_idx in range(run_length):
step_func(_t=times[run_idx], _i=run_idx, _dt=dt)

if profile.run_on_gpu():
self.runner.gpu_data_to_cpu()
self.mon['ts'] = times
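
# [Editor's usage sketch] The "(key, value, [operation])" input format above,
# with hypothetical names; an array whose first axis equals the run length is
# treated as an 'iter' input (one value per time step):
#   group.run(100., inputs=[('ST.input', 1.5),          # fixed scalar, default op '+'
#                           ('ST.input', inp_arr, '=')])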

def get_schedule(self):
"""Get the schedule (running order) of the update functions.

Returns
-------
schedule : list, tuple
The running order of update functions.
"""
return self.runner.get_schedule()

def set_schedule(self, schedule):
"""Set the schedule (running order) of the update functions.

For example, if the ``self.model`` has two step functions: `step1`, `step2`.
Then, you can set the schedule by using:

>>> set_schedule(['input', 'step1', 'step2', 'monitor'])
"""
self.runner.set_schedule(schedule)

@property
def requires(self):
return self.model.requires

+ 0
- 28
brainpy/core/constants.py View File

@@ -1,28 +0,0 @@
# -*- coding: utf-8 -*-

# argument keywords
KW_DT = '_dt'
KW_T = '_t'
KW_I = '_i'
ARG_KEYWORDS = ['_dt', '_t', '_i', '_obj_i', '_pre_i', '_post_i']

# name of the neuron group
NEU_GROUP_TYPE = 'NeuGroup'

# name of the synapse connection
SYN_CONN_TYPE = 'SynConn'

# input operations
INPUT_OPERATIONS = {'-': 'sub',
'+': 'add',
'x': 'mul',
'*': 'mul',
'/': 'div',
'=': 'assign'}
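
# [Editor's note] In the generated input code these strings select the update
# statement: '=' produces "left = right", every other key produces an
# augmented assignment such as "left += right" (see Runner.get_codes_of_input).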

# model mode
SCALAR_MODE = 'scalar'
VECTOR_MODE = 'vector'
MATRIX_MODE = 'matrix'



+ 0
- 331
brainpy/core/network.py View File

@@ -1,331 +0,0 @@
# -*- coding: utf-8 -*-

import time
from collections import OrderedDict

import numpy as np
from numba import cuda

from . import base
from . import constants
from . import utils
from .. import errors
from .. import profile

__all__ = [
'Network',
]


class Network(object):
"""The main simulation controller in ``BrainPy``.

``Network`` handles the running of a simulation. It contains a set of
objects that are added with `add()`. The `run()` method actually runs the
simulation. The main loop executes the objects in the order in which they
were added. The objects in the `Network` are accessible via their names,
e.g. `net.name` returns the object (a neuron group or a synapse connection).
"""

def __init__(self, *args, mode=None, **kwargs):
# record the current step
self.t_start = 0.
self.t_end = 0.

# store all objects
self._all_objects = OrderedDict()
self.add(*args, **kwargs)

# store the step function
self._step_func = None

if isinstance(mode, str):
print('The "repeat" mode of the network is set to the default. '
'After version 0.4.0, the "mode" setting will be removed.')

def _add_obj(self, obj, name=None):
# 1. check object type
if not isinstance(obj, base.Ensemble):
raise ValueError(f'Unknown object type "{type(obj)}". Network '
f'only supports NeuGroup and SynConn.')
# 2. check object name
name = obj.name if name is None else name
if name in self._all_objects:
raise KeyError(f'Name "{name}" has been used in the network, '
f'please choose another name.')
self._all_objects[name] = obj
# 3. add object to the network
setattr(self, name, obj)
if obj.name != name:
setattr(self, obj.name, obj)

def add(self, *args, **kwargs):
"""Add object (neurons or synapses) to the network.

Parameters
----------
args
The nameless objects.
kwargs
The named objects, which can be accessed by `net.xxx`
(xxx is the name of the object).
"""

for obj in args:
self._add_obj(obj)
for name, obj in kwargs.items():
self._add_obj(obj, name)

def format_inputs(self, inputs, run_length):
"""Format the user defined inputs.

Parameters
----------
inputs : tuple
The inputs.
run_length : int
The running length.

Returns
-------
formatted_input : dict
The formatted input.
"""

# 1. format the inputs to standard
# formats and check the inputs
if not isinstance(inputs, (tuple, list)):
raise errors.ModelUseError('"inputs" must be a tuple/list.')
if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)):
if isinstance(inputs[0], base.Ensemble):
inputs = [inputs]
else:
raise errors.ModelUseError(
'Unknown input structure. Only supports "(target, key, value, [operation])".')
for inp in inputs:
if not 3 <= len(inp) <= 4:
raise errors.ModelUseError('For each target, you must specify "(target, key, value, [operation])".')
if len(inp) == 4:
if inp[3] not in constants.INPUT_OPERATIONS:
raise errors.ModelUseError(f'Input operation only supports '
f'"{list(constants.INPUT_OPERATIONS.keys())}", '
f'not "{inp[3]}".')

# 2. format inputs
formatted_inputs = {}
for inp in inputs:
# target
if isinstance(inp[0], str):
target = getattr(self, inp[0]).name
elif isinstance(inp[0], base.Ensemble):
target = inp[0].name
else:
raise KeyError(f'Unknown input target: {str(inp[0])}')

# key
if not isinstance(inp[1], str):
raise errors.ModelUseError('For each input, input[1] must be a string '
'to specify variable of the target.')
key = inp[1]

# value and data type
if isinstance(inp[2], (int, float)):
val = inp[2]
data_type = 'fix'
elif isinstance(inp[2], np.ndarray):
val = inp[2]
if val.shape[0] == run_length:
data_type = 'iter'
else:
data_type = 'fix'
else:
raise errors.ModelUseError(f'For each input, input[2] must be a numerical value to '
f'specify input values, but got a {type(inp[2])}.')

# operation
if len(inp) == 4:
ops = inp[3]
else:
ops = '+'

# final result
if target not in formatted_inputs:
formatted_inputs[target] = []
format_inp = (key, val, ops, data_type)
formatted_inputs[target].append(format_inp)
return formatted_inputs
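
# [Editor's usage sketch] Network-level inputs carry an extra leading target,
# given either as the ensemble object or as its name in the network
# (hypothetical names):
#   net.run(100., inputs=[(lif_group, 'ST.input', 2.0),
#                         ('exc', 'ST.input', inp_arr, '=')])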

def build(self, run_length, inputs=()):
"""Build the network.

Parameters
----------
run_length : int
The running length.
inputs : tuple, list
The user-defined inputs.

Returns
-------
step_func : callable
The step function.
"""
if not isinstance(run_length, int):
raise errors.ModelUseError(f'The running length must be an int, but got {run_length}.')

# inputs
format_inputs = self.format_inputs(inputs, run_length)

# codes for step function
code_scopes = {}
code_lines = ['# network step function\ndef step_func(_t, _i, _dt):']
for obj in self._all_objects.values():
if profile.run_on_gpu():
if obj.model.mode != constants.SCALAR_MODE:
raise errors.ModelUseError(f'GPU mode only supports the scalar-based mode, '
f'but {obj.model} is a {obj.model.mode}-based model.')
code_scopes[obj.name] = obj
code_scopes[f'{obj.name}_runner'] = obj.runner
lines_of_call = obj.build(inputs=format_inputs.get(obj.name, None), mon_length=run_length)
code_lines.extend(lines_of_call)
if profile.run_on_gpu():
code_scopes['cuda'] = cuda
func_code = '\n '.join(code_lines)

# compile the step function
exec(compile(func_code, '', 'exec'), code_scopes)
step_func = code_scopes['step_func']

# show
if profile.show_format_code():
utils.show_code_str(func_code.replace('def ', 'def network_'))
if profile.show_code_scope():
utils.show_code_scope(code_scopes, ['__builtins__', 'step_func'])

return step_func

def run(self, duration, inputs=(), report=False, report_percent=0.1,
data_to_host=False, verbose=True):
"""Run the simulation for the given duration.

This function provides the most convenient way to run the network.

Parameters
----------
duration : int, float, tuple, list
The amount of simulation time to run for.
inputs : list, tuple
The receivers, external inputs and durations.
report : bool
Report the progress of the simulation.
report_percent : float
The fraction of progress at which to report.
data_to_host : bool
Transfer the gpu data to cpu. Available in CUDA backend.
verbose : bool
Show the error information.
"""
# check the duration
# ------------------
if isinstance(duration, (int, float)):
start, end = 0, duration
elif isinstance(duration, (tuple, list)):
if len(duration) != 2:
raise errors.ModelUseError('Only support duration with the format of "(start, end)".')
start, end = duration
else:
raise ValueError(f'Unknown duration type: {type(duration)}')
self.t_start, self.t_end = start, end
dt = profile.get_dt()
ts = np.asarray(np.arange(start, end, dt), dtype=np.float_)
run_length = ts.shape[0]

if self._step_func is None:
# initialize the function
# -----------------------
self._step_func = self.build(run_length, inputs)
else:
# check and reset inputs
# ----------------------
input_keep_same = True
formatted_inputs = self.format_inputs(inputs, run_length)
for obj in self._all_objects.values():
obj_name = obj.name
obj_inputs = obj.runner._inputs
obj_input_keys = list(obj_inputs.keys())
if obj_name in formatted_inputs:
current_inputs = formatted_inputs[obj_name]
else:
current_inputs = []
for key, val, ops, data_type in current_inputs:
if np.shape(obj_inputs[key][0]) != np.shape(val):
if verbose:
print(f'The current "{key}" input shape {np.shape(val)} is different '
f'from the last input shape {np.shape(obj_inputs[key][0])}. ')
input_keep_same = False
if obj_inputs[key][1] != ops:
if verbose:
print(f'The current "{key}" input operation "{ops}" is different '
f'from the last operation "{obj_inputs[key][1]}". ')
input_keep_same = False
obj.runner.set_data(f'{key.replace(".", "_")}_inp', val)
if key in obj_input_keys:
obj_input_keys.remove(key)
else:
input_keep_same = False
if verbose:
print(f'A new input key "{key}" is provided for {obj_name}.')
if len(obj_input_keys):
input_keep_same = False
if verbose:
print(f'The inputs of {obj_input_keys} in {obj_name} are not provided.')
if input_keep_same:
# reset monitors
# --------------
for obj in self._all_objects.values():
obj.reshape_mon(run_length)
else:
if verbose:
print('The network will be rebuilt.')
self._step_func = self.build(run_length, inputs)

dt = self.dt
if report:
# Run the model with progress report
# ----------------------------------
t0 = time.time()
self._step_func(_t=ts[0], _i=0, _dt=dt)
print('Compilation used {:.4f} s.'.format(time.time() - t0))

print("Start running ...")
report_gap = max(int(run_length * report_percent), 1)
t0 = time.time()
for run_idx in range(1, run_length):
self._step_func(_t=ts[run_idx], _i=run_idx, _dt=dt)
if (run_idx + 1) % report_gap == 0:
percent = (run_idx + 1) / run_length * 100
print('Run {:.1f}% used {:.3f} s.'.format(percent, time.time() - t0))
print('Simulation is done in {:.3f} s.'.format(time.time() - t0))
else:
# Run the model
# -------------
for run_idx in range(run_length):
self._step_func(_t=ts[run_idx], _i=run_idx, _dt=dt)

# format monitor
# --------------
for obj in self._all_objects.values():
obj.mon['ts'] = self.ts
if data_to_host and profile.run_on_gpu():
obj.runner.gpu_data_to_cpu()

@property
def ts(self):
"""Get the time points of the network.
"""
return np.array(np.arange(self.t_start, self.t_end, self.dt), dtype=np.float_)

@property
def dt(self):
return profile.get_dt()

+ 0
- 181
brainpy/core/neurons.py View File

@@ -1,181 +0,0 @@
# -*- coding: utf-8 -*-

import numpy as np

from . import base
from . import constants
from . import utils
from .. import errors

__all__ = [
'NeuType',
'NeuGroup',
'NeuSubGroup',
]

_NEU_GROUP_NO = 0


class NeuType(base.ObjType):
"""Abstract Neuron Type.

It can be defined based on a group of neurons or a single neuron.
"""

def __init__(self, name, ST, steps, mode='vector', requires=None, hand_overs=None, ):
if mode not in [constants.SCALAR_MODE, constants.VECTOR_MODE]:
raise errors.ModelDefError('NeuType only support "scalar" or "vector".')

super(NeuType, self).__init__(
ST=ST,
requires=requires,
steps=steps,
name=name,
mode=mode,
hand_overs=hand_overs)


class NeuGroup(base.Ensemble):
"""Neuron Group.

Parameters
----------
model : NeuType
The instantiated neuron type model.
geometry : int, tuple
The neuron group geometry.
pars_update : dict
Parameters to update.
monitors : list, tuple
Variables to monitor.
name : str
The name of the neuron group.
"""

def __init__(self, model, geometry, monitors=None, name=None, satisfies=None, pars_update=None, ):
# name
# -----
if name is None:
global _NEU_GROUP_NO
name = f'NeuGroup{_NEU_GROUP_NO}'
_NEU_GROUP_NO += 1
else:
name = name

# num and geometry
# -----------------
if isinstance(geometry, (int, float)):
geometry = num = int(geometry)
self.indices = np.asarray(np.arange(int(geometry)), dtype=np.int_)
elif isinstance(geometry, (tuple, list)):
if len(geometry) == 1:
geometry = num = geometry[0]
indices = np.arange(num)
elif len(geometry) == 2:
height, width = geometry[0], geometry[1]
num = height * width
indices = np.arange(num).reshape((height, width))
else:
raise errors.ModelUseError('Networks with 3+ dimensional geometry are not supported.')
self.indices = np.asarray(indices, dtype=np.int_)
else:
raise ValueError()
self.geometry = geometry
self.size = np.size(self.indices)

# model
# ------
try:
assert isinstance(model, NeuType)
except AssertionError:
raise errors.ModelUseError(f'{NeuGroup.__name__} receives an '
f'instance of {NeuType.__name__}, '
f'not {type(model).__name__}.')

# initialize
# ----------
super(NeuGroup, self).__init__(model=model,
pars_update=pars_update,
name=name,
num=num,
monitors=monitors,
cls_type=constants.NEU_GROUP_TYPE,
satisfies=satisfies)

# ST
# --
self.ST = self.model.ST.make_copy(num)

def __getitem__(self, item):
"""Return a subset of neuron group.

Parameters
----------
item : slice, int, tuple of slice

Returns
-------
sub_group : NeuSubGroup
The subset of the neuron group.
"""

if isinstance(item, int):
try:
assert item < self.num
except AssertionError:
raise errors.ModelUseError(f'Index error, because the maximum number of neurons '
f'is {self.num}, but got "item={item}".')
d1_start, d1_end, d1_step = item, item + 1, 1
utils.check_slice(d1_start, d1_end, self.num)
indices = self.indices[d1_start:d1_end:d1_step]
elif isinstance(item, slice):
d1_start, d1_end, d1_step = item.indices(self.num)
utils.check_slice(d1_start, d1_end, self.num)
indices = self.indices[d1_start:d1_end:d1_step]
elif isinstance(item, tuple):
if not isinstance(self.geometry, (tuple, list)):
raise errors.ModelUseError(f'{self.name} has a 1D geometry, cannot use a tuple of slice.')
if len(item) != 2:
raise errors.ModelUseError(f'Only support 2D network, cannot make {len(item)}D slice.')

if isinstance(item[0], slice):
d1_start, d1_end, d1_step = item[0].indices(self.geometry[0])
elif isinstance(item[0], int):
d1_start, d1_end, d1_step = item[0], item[0] + 1, 1
else:
raise errors.ModelUseError("Only support slicing syntax or a single index.")
utils.check_slice(d1_start, d1_end, self.geometry[0])

if isinstance(item[1], slice):
d2_start, d2_end, d2_step = item[1].indices(self.geometry[1])
elif isinstance(item[1], int):
d2_start, d2_end, d2_step = item[1], item[1] + 1, 1
else:
raise errors.ModelUseError("Only support slicing syntax or a single index.")
utils.check_slice(d2_start, d2_end, self.geometry[1])

indices = self.indices[d1_start:d1_end:d1_step, d2_start:d2_end:d2_step]
else:
raise errors.ModelUseError('Subgroups can only be constructed using slicing syntax, '
'a single index, or an array of contiguous indices.')

return NeuSubGroup(source=self, indices=indices)
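
# [Editor's usage sketch] Subgrouping by slicing, with a hypothetical model:
#   group = NeuGroup(model, geometry=(10, 10))
#   sub = group[2:5, :]   # rows 2-4, all 10 columns -> a 30-neuron NeuSubGroup
#   one = group[3]        # a single index is also accepted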


class NeuSubGroup(object):
"""Subset of a `NeuGroup`.
"""

def __init__(self, source, indices):
if not isinstance(source, NeuGroup):
raise errors.ModelUseError('NeuSubGroup only supports an instance of NeuGroup.')

self.source = source
self.indices = indices
self.num = np.size(indices)

def __getattr__(self, item):
if item in ['source', 'indices', 'num']:
return object.__getattribute__(self, item)
else:
return getattr(self.source, item)

+ 0
- 1256
brainpy/core/runner.py View File

@@ -1,1256 +0,0 @@
# -*- coding: utf-8 -*-

import ast
import inspect
import math
import re

import numba
import numpy as np
from numba import cuda
from numba.cuda.random import create_xoroshiro128p_states
from numba.cuda.random import xoroshiro128p_normal_float64

from . import constants
from . import types
from . import utils
from .. import errors
from .. import integration
from .. import profile
from .. import tools
from ..tools import NoiseHandler

__all__ = [
'Runner',
'TrajectoryRunner',
]



class Runner(object):
"""Basic runner class.

Parameters
----------
ensemble : NeuGroup, SynConn
The ensemble of the models.
"""

def __init__(self, ensemble):
# ensemble: NeuGroup / SynConn
self.ensemble = ensemble
# ensemble model
self._model = ensemble.model
# ensemble name
self._name = ensemble.name
# ensemble parameters
self._pars = ensemble.pars
# model delay keys
self._delay_keys = ensemble.model._delay_keys
# model step functions
self._steps = ensemble.model.steps
self._step_names = ensemble.model.step_names
# model update schedule
self._schedule = ['input'] + ensemble.model.step_names + ['monitor']
self._inputs = {}
self.gpu_data = {}

def check_attr(self, attr):
if not hasattr(self, attr):
raise errors.ModelUseError(f'Model "{self._name}" doesn\'t have a "{attr}" attribute, '
f'and "{self._name}.ST" doesn\'t have an "{attr}" field.')

def get_codes_of_input(self, key_val_ops_types):
"""Format the code of external input.

Parameters
----------
key_val_ops_types : list, tuple
The inputs.

Returns
-------
code : dict
The formatted code.
"""
if len(key_val_ops_types) <= 0:
raise errors.ModelUseError(f'{self._name} has no input, cannot call this function.')

# check datatype of the input
# ----------------------------
has_iter = False
all_inputs = set()
for key, val, ops, t in key_val_ops_types:
if t not in ['iter', 'fix']:
raise errors.ModelUseError('Only support inputs of "iter" and "fix" types.')
if t == 'iter':
has_iter = True
if key in all_inputs:
raise errors.ModelUseError('Each input key can only be assigned once.')
else:
self._inputs[key] = (val, ops, t)
all_inputs.add(key)

# check data operations
# ----------------------
for _, _, ops, _ in key_val_ops_types:
if ops not in constants.INPUT_OPERATIONS:
raise errors.ModelUseError(
f'Only support five input operations: {list(constants.INPUT_OPERATIONS.keys())}')

# generate code of input function
# --------------------------------
if profile.run_on_cpu():
code_scope = {self._name: self.ensemble, f'{self._name}_runner': self}
code_args, code_arg2call, code_lines = set(), {}, []
if has_iter:
code_args.add('_i')
code_arg2call['_i'] = '_i'

input_idx = 0
for key, val, ops, data_type in key_val_ops_types:
# get the left side #
attr_item = key.split('.')
if len(attr_item) == 1 and (attr_item[0] not in self.ensemble.ST):
# if "item" is the model attribute
attr, item = attr_item[0], ''
target = getattr(self.ensemble, attr)
self.check_attr(attr)
if not isinstance(target, np.ndarray):
raise errors.ModelUseError('BrainPy only supports input to arrays.')
left = attr
code_args.add(left)
code_arg2call[left] = f'{self._name}.{attr}'
else:
if len(attr_item) == 1:
attr, item = 'ST', attr_item[0]
elif len(attr_item) == 2:
attr, item = attr_item[0], attr_item[1]
else:
raise errors.ModelUseError(f'Unknown target: {key}.')
data = getattr(self.ensemble, attr)
if item not in data:
raise errors.ModelUseError(f'"{self._name}.{attr}" doesn\'t have "{item}" field.')
idx = data['_var2idx'][item]
left = f'{attr}[{idx}]'
code_args.add(attr)
code_arg2call[attr] = f'{self._name}.{attr}["_data"]'

# get the right side #
right = f'{key.replace(".", "_")}_inp'
code_args.add(right)
code_arg2call[right] = f'{self._name}_runner.{right}'
self.set_data(right, val)
if data_type == 'iter':
right = right + '[_i]'
if np.ndim(val) > 1:
pass
input_idx += 1

# final code line #
if ops == '=':
code_lines.append(f"{left} = {right}")
else:
code_lines.append(f"{left} {ops}= {right}")

# final code
# ----------
code_lines.insert(0, f'# "input" step function of {self._name}')
code_lines.append('\n')

# compile function
code_to_compile = [f'def input_step({tools.func_call(code_args)}):'] + code_lines
func_code = '\n '.join(code_to_compile)
exec(compile(func_code, '', 'exec'), code_scope)
input_step = code_scope['input_step']
# if profile.is_jit():
# input_step = tools.jit(input_step)
self.input_step = input_step
if not profile.is_merge_steps():
if profile.show_format_code():
utils.show_code_str(func_code.replace('def ', f'def {self._name}_'))
if profile.show_code_scope():
utils.show_code_scope(code_scope, ['__builtins__', 'input_step'])

# format function call
arg2call = [code_arg2call[arg] for arg in sorted(list(code_args))]
func_call = f'{self._name}_runner.input_step({tools.func_call(arg2call)})'

return {'input': {'scopes': code_scope,
'args': code_args,
'arg2calls': code_arg2call,
'codes': code_lines,
'call': func_call}}

else:
input_idx = 0
results = {}
for key, val, ops, data_type in key_val_ops_types:
code_scope = {self._name: self.ensemble, f'{self._name}_runner': self, 'cuda': cuda}
code_args, code_arg2call, code_lines = set(), {}, []
if has_iter:
code_args.add('_i')
code_arg2call['_i'] = '_i'

attr_item = key.split('.')
if len(attr_item) == 1 and (attr_item[0] not in self.ensemble.ST):
# if "item" is the model attribute
attr, item = attr_item[0], ''
self.check_attr(attr)
target = getattr(self.ensemble, attr)
if not isinstance(target, np.ndarray):
raise errors.ModelUseError(f'BrainPy only supports input to arrays.')
# get the left side
left = f'{attr}[cuda_i]'
self.set_gpu_data(f'{attr}_cuda', target)
else:
# if "item" is the ObjState
if len(attr_item) == 1:
attr, item = 'ST', attr_item[0]
elif len(attr_item) == 2:
attr, item = attr_item[0], attr_item[1]
else:
raise errors.ModelUseError(f'Unknown target: {key}.')
data = getattr(self.ensemble, attr)
if item not in data:
raise errors.ModelUseError(f'"{self._name}.{attr}" doesn\'t have "{item}" field.')
# get the left side
target = data[item]
idx = data['_var2idx'][item]
left = f'{attr}[{idx}, cuda_i]'
self.set_gpu_data(f'{attr}_cuda', data)
code_args.add(f'{attr}')
code_arg2call[f'{attr}'] = f'{self._name}_runner.{attr}_cuda'

# get the right side #
right = f'{key.replace(".", "_")}_inp'
self.set_data(right, val)
code_args.add(right)
code_arg2call[right] = f'{self._name}_runner.{right}'

# check data type
iter_along_time = data_type == 'iter'
if np.isscalar(val):
iter_along_data = False
else:
if iter_along_time:
if np.isscalar(val[0]):
iter_along_data = False
else:
assert len(val[0]) == len(target)
iter_along_data = True
else:
assert len(val) == len(target)
iter_along_data = True
if iter_along_time and iter_along_data:
right = right + '[_i, cuda_i]'
elif iter_along_time:
right = right + '[_i]'
elif iter_along_data:
right = right + '[cuda_i]'
else:
right = right

# final code line
if ops == '=':
code_lines.append(f"{left} = {right}")
else:
code_lines.append(f"{left} {ops}= {right}")
code_lines = [' ' + line for line in code_lines]
code_lines.insert(0, f'if cuda_i < {len(target)}:')

# final code
func_name = f'input_of_{attr}_{item}'
code_to_compile = [f'# "input" of {self._name}.{attr}.{item}',
f'def {func_name}({tools.func_call(code_args)}):',
f' cuda_i = cuda.grid(1)']
code_to_compile += [f' {line}' for line in code_lines]

# compile function
func_code = '\n'.join(code_to_compile)
exec(compile(func_code, '', 'exec'), code_scope)
step_func = code_scope[func_name]
step_func = cuda.jit(step_func)
setattr(self, func_name, step_func)
if not profile.is_merge_steps():
if profile.show_format_code():
utils.show_code_str(func_code.replace('def ', f'def {self._name}_'))
if profile.show_code_scope():
utils.show_code_scope(code_scope, ['__builtins__', 'input_step'])

# format function call
if len(target) <= profile.get_num_thread_gpu():
num_thread = len(target)
num_block = 1
else:
num_thread = profile.get_num_thread_gpu()
num_block = math.ceil(len(target) / profile.get_num_thread_gpu())
arg2call = [code_arg2call[arg] for arg in sorted(list(code_args))]
func_call = f'{self._name}_runner.{func_name}[{num_block}, {num_thread}]({tools.func_call(arg2call)})'

# function result
results[f'input-{input_idx}'] = {'scopes': code_scope,
'args': code_args,
'arg2calls': code_arg2call,
'codes': code_lines,
'call': func_call,
'num_data': len(target)}

# iteration
input_idx += 1

return results
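
# [Editor's note] A hypothetical reconstruction of the CPU code this method
# generates for inputs=[('ST.input', 1.5, '+', 'fix')], assuming 'input' sits
# at row 3 of ST['_data']:
#   def input_step(ST, ST_input_inp):
#       # "input" step function of NeuGroup0
#       ST[3] += ST_input_inp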

def get_codes_of_monitor(self, mon_vars, run_length):
"""Get the code of the monitors.

Parameters
----------
mon_vars : tuple, list
The variables to monitor.
run_length : int
The running length.

Returns
-------
code : dict
The formatted code.
"""
if len(mon_vars) <= 0:
raise errors.ModelUseError(f'{self._name} has no monitor, cannot call this function.')

# check indices #
for key, indices in mon_vars:
if indices is not None:
if isinstance(indices, list):
if not isinstance(indices[0], int):
raise errors.ModelUseError('Monitor index only supports list [int] or 1D array.')
elif isinstance(indices, np.ndarray):
if np.ndim(indices) != 1:
raise errors.ModelUseError('Monitor index only supports list [int] or 1D array.')
else:
raise errors.ModelUseError(f'Unknown monitor index type: {type(indices)}.')

if profile.run_on_cpu():
# monitor
mon = tools.DictPlus()

code_scope = {self._name: self.ensemble, f'{self._name}_runner': self}
code_args, code_arg2call, code_lines = set(), {}, []

# generate code of monitor function
# ---------------------------------
mon_idx = 0
for key, indices in mon_vars:
if indices is not None:
indices = np.asarray(indices)
attr_item = key.split('.')

# get the code line #
if (len(attr_item) == 1) and (attr_item[0] not in self.ensemble.ST):
attr = attr_item[0]
self.check_attr(attr)
data = getattr(self.ensemble, attr)
if not isinstance(data, np.ndarray):
raise errors.ModelUseError('BrainPy only supports monitoring of arrays.')
shape = data.shape
mon_name = f'mon_{attr}'
target_name = attr
if indices is None:
line = f'{mon_name}[_i] = {target_name}'
else:
idx_name = f'idx{mon_idx}_{attr}'
line = f'{mon_name}[_i] = {target_name}[{idx_name}]'
code_scope[idx_name] = indices
code_args.add(mon_name)
code_arg2call[mon_name] = f'{self._name}.mon["{key}"]'
code_args.add(target_name)
code_arg2call[target_name] = f'{self._name}.{attr}'
else:
if len(attr_item) == 1:
item, attr = attr_item[0], 'ST'
elif len(attr_item) == 2:
attr, item = attr_item
else:
raise errors.ModelUseError(f'Unknown target: {key}.')
data = getattr(self.ensemble, attr)
shape = data[item].shape
idx = data['_var2idx'][item]
mon_name = f'mon_{attr}_{item}'
target_name = attr
if indices is None:
line = f'{mon_name}[_i] = {target_name}[{idx}]'
else:
idx_name = f'idx{mon_idx}_{attr}_{item}'
line = f'{mon_name}[_i] = {target_name}[{idx}][{idx_name}]'
code_scope[idx_name] = indices
code_args.add(mon_name)
code_arg2call[mon_name] = f'{self._name}.mon["{key}"]'
code_args.add(target_name)
code_arg2call[target_name] = f'{self._name}.{attr}["_data"]'
mon_idx += 1

# initialize monitor array #
key = key.replace('.', '_')
if indices is None:
mon[key] = np.zeros((run_length,) + shape, dtype=np.float_)
else:
mon[key] = np.zeros((run_length, len(indices)) + shape[1:], dtype=np.float_)

# add line #
code_lines.append(line)

# final code
# ----------
code_lines.insert(0, f'# "monitor" step function of {self._name}')
code_lines.append('\n')
code_args.add('_i')
code_arg2call['_i'] = '_i'

# compile function
code_to_compile = [f'def monitor_step({tools.func_call(code_args)}):'] + code_lines
func_code = '\n '.join(code_to_compile)

if not profile.is_merge_steps():
if profile.show_format_code():
utils.show_code_str(func_code.replace('def ', f'def {self._name}_'))
if profile.show_code_scope():
utils.show_code_scope(code_scope, ('__builtins__', 'monitor_step'))

exec(compile(func_code, '', 'exec'), code_scope)
monitor_step = code_scope['monitor_step']
# if profile.is_jit():
# monitor_step = tools.jit(monitor_step)
self.monitor_step = monitor_step

# format function call
arg2call = [code_arg2call[arg] for arg in sorted(list(code_args))]
func_call = f'{self._name}_runner.monitor_step({tools.func_call(arg2call)})'

return mon, {'monitor': {'scopes': code_scope,
'args': code_args,
'arg2calls': code_arg2call,
'codes': code_lines,
'call': func_call}}

else:
results = {}
mon = tools.DictPlus()

# generate code of monitor function
# ---------------------------------
mon_idx = 0
for key, indices in mon_vars:
if indices is not None:
indices = np.asarray(indices)
code_scope = {self._name: self.ensemble, f'{self._name}_runner': self}
code_args, code_arg2call, code_lines = set(), {}, []

attr_item = key.split('.')
key = key.replace(".", "_")
# get the code line #
if (len(attr_item) == 1) and (attr_item[0] not in self.ensemble.ST):
attr, item = attr_item[0], ''
self.check_attr(attr)
if not isinstance(getattr(self.ensemble, attr), np.ndarray):
raise errors.ModelUseError('BrainPy only supports monitoring of arrays.')
data = getattr(self.ensemble, attr)
shape = data.shape
mon_name = f'mon_{attr}'
target_name = f'{attr}_cuda'
if indices is None:
num_data = shape[0]
line = f'{mon_name}[_i, cuda_i] = {target_name}[cuda_i]'
else:
num_data = len(indices)
idx_name = f'idx{mon_idx}_{attr}'
code_lines.append(f'mon_idx = {idx_name}[cuda_i]')
line = f'{mon_name}[_i, cuda_i] = {target_name}[mon_idx]'
code_scope[idx_name] = cuda.to_device(indices)
code_args.add(mon_name)
code_arg2call[mon_name] = f'{self._name}_runner.mon_{key}_cuda'
self.set_gpu_data(f'{attr}_cuda', data)
code_args.add(target_name)
code_arg2call[target_name] = f'{self._name}_runner.{attr}_cuda'
else:
if len(attr_item) == 1:
item, attr = attr_item[0], 'ST'
elif len(attr_item) == 2:
attr, item = attr_item
else:
raise errors.ModelUseError(f'Unknown target: {key}.')
data = getattr(self.ensemble, attr)
shape = data[item].shape
idx = getattr(self.ensemble, attr)['_var2idx'][item]
mon_name = f'mon_{attr}_{item}'
target_name = attr
if indices is None:
num_data = shape[0]
line = f'{mon_name}[_i, cuda_i] = {target_name}[{idx}, cuda_i]'
else:
num_data = len(indices)
idx_name = f'idx{mon_idx}_{attr}_{item}'
code_lines.append(f'mon_idx = {idx_name}[cuda_i]')
line = f'{mon_name}[_i, cuda_i] = {target_name}[{idx}, mon_idx]'
code_scope[idx_name] = cuda.to_device(indices)
code_args.add(mon_name)
code_arg2call[mon_name] = f'{self._name}_runner.mon_{key}_cuda'
self.set_gpu_data(f'{attr}_cuda', data)
code_args.add(target_name)
code_arg2call[target_name] = f'{self._name}_runner.{attr}_cuda'

# initialize monitor array #
if indices is None:
mon[key] = np.zeros((run_length,) + shape, dtype=np.float_)
else:
mon[key] = np.zeros((run_length, num_data) + shape[1:], dtype=np.float_)
self.set_gpu_data(f'mon_{key}_cuda', mon[key])

# add line #
code_args.add('_i')
code_arg2call['_i'] = '_i'
code_scope['cuda'] = cuda

# final code
# ----------
code_lines.append(line)
code_lines = [' ' + line for line in code_lines]
code_lines.insert(0, f'if cuda_i < {num_data}:')

# compile function
func_name = f'monitor_of_{attr}_{item}'
code_to_compile = [f'# "monitor" of {self._name}.{attr}.{item}',
f'def {func_name}({tools.func_call(code_args)}):',
f' cuda_i = cuda.grid(1)']
code_to_compile += [f' {line}' for line in code_lines]
func_code = '\n'.join(code_to_compile)
exec(compile(func_code, '', 'exec'), code_scope)
monitor_step = code_scope[func_name]
monitor_step = cuda.jit(monitor_step)
setattr(self, func_name, monitor_step)

if not profile.is_merge_steps():
if profile.show_format_code():
utils.show_code_str(func_code.replace('def ', f'def {self._name}_'))
if profile.show_code_scope():
utils.show_code_scope(code_scope, ('__builtins__', 'monitor_step'))

# format function call
if num_data <= profile.get_num_thread_gpu():
num_thread = num_data
num_block = 1
else:
num_thread = profile.get_num_thread_gpu()
num_block = math.ceil(num_data / profile.get_num_thread_gpu())
arg2call = [code_arg2call[arg] for arg in sorted(list(code_args))]
func_call = f'{self._name}_runner.{func_name}[{num_block}, {num_thread}]({tools.func_call(arg2call)})'

results[f'monitor-{mon_idx}'] = {'scopes': code_scope,
'args': code_args,
'arg2calls': code_arg2call,
'codes': code_lines,
'call': func_call,
'num_data': num_data}

mon_idx += 1

return mon, results
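
# [Editor's note] A hypothetical reconstruction of the CPU code generated for
# monitors=['V'], assuming 'V' sits at row 0 of ST['_data']:
#   def monitor_step(ST, _i, mon_ST_V):
#       # "monitor" step function of NeuGroup0
#       mon_ST_V[_i] = ST[0]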

def get_codes_of_steps(self):
"""Get the code of user defined update steps.

Returns
-------
code : dict
The formatted code.
"""
if self._model.mode == constants.SCALAR_MODE:
return self.step_scalar_model()
else:
return self.step_vector_model()

def format_step_code(self, func_code):
"""Format code of user defined step function.

Parameters
----------
func_code : str
The user defined function codes.
"""
tree = ast.parse(func_code.strip())
formatter = tools.CodeLineFormatter()
formatter.visit(tree)
return formatter

def merge_integrators(self, func):
"""Substitute the user defined integrators into the main step functions.

Parameters
----------
func : callable
The user defined (main) step function.

Returns
-------
results : tuple
The code lines, the code scope, and the code formatter.
"""
# get code and code lines
func_code = tools.deindent(tools.get_main_code(func))
formatter = self.format_step_code(func_code)
code_lines = formatter.lines

# get function scope
vars = inspect.getclosurevars(func)
code_scope = dict(vars.nonlocals)
code_scope.update(vars.globals)
code_scope.update({self._name: self.ensemble})
code_scope.update(formatter.scope)
if len(code_lines) == 0:
return '', code_scope, formatter

# code scope update
scope_to_add = {}
scope_to_del = set()
need_add_mapping_scope = False
for k, v in code_scope.items():
if isinstance(v, integration.Integrator):
if profile.is_merge_integrators():
need_add_mapping_scope = True

# locate the integration function
need_replace = False
int_func_name = v.py_func_name
for line_no, line in enumerate(code_lines):
if int_func_name in tools.get_identifiers(line):
need_replace = True
break
if not need_replace:
scope_to_del.add(k)
continue

# get integral function line indent
line_indent = tools.get_line_indent(line)
indent = ' ' * line_indent

# get the replace line and arguments need to replace
new_line, args, kwargs = tools.replace_func(line, int_func_name)
# append code line of argument replacement
func_args = v.diff_eq.func_args
append_lines = [indent + f'_{v.py_func_name}_{func_args[i]} = {args[i]}'
for i in range(len(args))]
for arg in func_args[len(args):]:
append_lines.append(indent + f'_{v.py_func_name}_{arg} = {kwargs[arg]}')

# append numerical integration code lines
append_lines.extend([indent + l for l in v.update_code.split('\n')])
append_lines.append(indent + new_line)

# add appended lines into the main function code lines
code_lines = code_lines[:line_no] + append_lines + code_lines[line_no + 1:]

# get scope variables to delete
scope_to_del.add(k)
for k_, v_ in v.code_scope.items():
if profile.is_jit() and callable(v_):
v_ = tools.numba_func(v_, params=self._pars.updates)
scope_to_add[k_] = v_

else:
if self._model.mode == constants.SCALAR_MODE:
for ks, vs in utils.get_func_scope(v.update_func, include_dispatcher=True).items():
if ks in self._pars.heters:
raise errors.ModelUseError(
f'Heterogeneous parameter "{ks}" is not in step functions, '
f'it will not work. Please set "brainpy.profile.set(merge_integrators=True)" '
f'to try to merge parameter "{ks}" into the step functions.')
if profile.is_jit():
code_scope[k] = tools.numba_func(v.update_func, params=self._pars.updates)

elif type(v).__name__ == 'function':
if profile.is_jit():
code_scope[k] = tools.numba_func(v, params=self._pars.updates)

# update code scope
if need_add_mapping_scope:
code_scope.update(integration.get_mapping_scope())
code_scope.update(scope_to_add)
for k in scope_to_del:
code_scope.pop(k)

# return code lines and code scope
return '\n'.join(code_lines), code_scope, formatter
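
# [Editor's note] A conceptual before/after of the inlining above, for a
# hypothetical integrator int_V used inside a step function:
#   before:  ST['V'] = int_V(ST['V'], _t, ST['input'])
#   after:   _int_V_V = ST['V']          # one assignment per integrator argument
#            _int_V_t = _t
#            _int_V_input = ST['input']
#            <numerical update code of int_V>
#            ST['V'] = <result of the replaced call>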

def step_vector_model(self):
results = dict()

# check whether the model include heterogeneous parameters
delay_keys = self._delay_keys

for func in self._steps:
# information about the function
func_name = func.__name__
stripped_fname = tools.get_func_name(func, replace=True)
func_args = inspect.getfullargspec(func).args

# initialize code namespace
used_args, code_arg2call = set(), {}
func_code, code_scope, formatter = self.merge_integrators(func)
code_scope[f'{self._name}_runner'] = self

# check function code
try:
states = {k: getattr(self.ensemble, k) for k in func_args
if k not in constants.ARG_KEYWORDS and
isinstance(getattr(self.ensemble, k), types.ObjState)}
except AttributeError:
raise errors.ModelUseError(f'Model "{self._name}" does not have all the '
f'required attributes: {func_args}.')
add_args = set()
for i, arg in enumerate(func_args):
used_args.add(arg)
if len(states) == 0:
continue
if arg in states:
st = states[arg]
var2idx = st['_var2idx']

if self.ensemble._is_state_attr(arg):
# Function with "delayed" decorator should use
# ST pulled from the delay queue
if func_name.startswith('_brainpy_delayed_'):
if len(delay_keys):
dout = f'{arg}_dout'
add_args.add(dout)
code_arg2call[dout] = f'{self._name}.{arg}._delay_out'
for st_k in delay_keys:
p = f'{arg}\[([\'"]{st_k}[\'"])\]'
r = f"{arg}[{var2idx['_' + st_k + '_offset']} + {dout}]"
func_code = re.sub(r'' + p, r, func_code)
else:
# Function without "delayed" decorator should push their
# updated ST to the delay queue
if len(delay_keys):
func_code_left = '\n'.join(formatter.lefts)
func_keys = set(re.findall(r'' + arg + r'\[[\'"](\w+)[\'"]\]', func_code_left))
func_delay_keys = func_keys.intersection(delay_keys)
if len(func_delay_keys) > 0:
din = f'{arg}_din'
add_args.add(din)
code_arg2call[din] = f'{self._name}.{arg}._delay_in'
for st_k in func_delay_keys:
right = f'{arg}[{var2idx[st_k]}]'
left = f"{arg}[{var2idx['_' + st_k + '_offset']} + {din}]"
func_code += f'\n{left} = {right}'

# replace key access to index access
for st_k in st._keys:
p = f'{arg}\[([\'"]{st_k}[\'"])\]'
r = f"{arg}[{var2idx[st_k]}]"
func_code = re.sub(r'' + p, r, func_code)

# substitute arguments
code_args = add_args
for arg in used_args:
if arg in constants.ARG_KEYWORDS:
code_arg2call[arg] = arg
else:
if isinstance(getattr(self.ensemble, arg), types.ObjState):
code_arg2call[arg] = f'{self._name}.{arg}["_data"]'
else:
code_arg2call[arg] = f'{self._name}.{arg}'
code_args.add(arg)

# substitute "range" to "numba.prange"
arg_substitute = {}
if ' range' in func_code:
arg_substitute['range'] = 'numba.prange'
code_scope['numba'] = numba
func_code = tools.word_replace(func_code, arg_substitute)

# update code scope
for k in list(code_scope.keys()):
if k in self._pars.updates:
code_scope[k] = self._pars.updates[k]

# handle the "_normal_like_"
func_code = NoiseHandler.normal_pattern.sub(NoiseHandler.vector_replace_f, func_code)
code_scope['numpy'] = np

# final
code_lines = func_code.split('\n')
code_lines.insert(0, f'# "{stripped_fname}" step function of {self._name}')
code_lines.append('\n')

# code to compile
code_to_compile = [f'def {stripped_fname}({tools.func_call(code_args)}):']
code_to_compile += code_lines
func_code = '\n '.join(code_to_compile)
exec(compile(func_code, '', 'exec'), code_scope)
func = code_scope[stripped_fname]
if profile.is_jit():
func = tools.jit(func)
if not profile.is_merge_steps():
if profile.show_format_code():
utils.show_code_str(func_code.replace('def ', f'def {self._name}_'))
if profile.show_code_scope():
utils.show_code_scope(code_scope, ['__builtins__', stripped_fname])

# set the function to the model
setattr(self, stripped_fname, func)
# function call
arg2calls = [code_arg2call[arg] for arg in sorted(list(code_args))]
func_call = f'{self._name}_runner.{stripped_fname}({tools.func_call(arg2calls)})'

results[stripped_fname] = {'scopes': code_scope,
'args': code_args,
'arg2calls': code_arg2call,
'codes': code_lines,
'call': func_call}

return results
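
# [Editor's sketch] The key-to-index rewrite applied above, in isolation, with
# a hypothetical variable map:
#   >>> import re
#   >>> var2idx = {'V': 0, 'input': 3}
#   >>> code = "ST['V'] += ST['input'] * _dt"
#   >>> for k, idx in var2idx.items():
#   ...     code = re.sub(rf"ST\[(['\"]){k}\1\]", f"ST[{idx}]", code)
#   >>> code
#   "ST[0] += ST[3] * _dt"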

def step_scalar_model(self):
results = dict()

# check whether the model include heterogeneous parameters
delay_keys = self._delay_keys
all_heter_pars = set(self._pars.heters.keys())

for i, func in enumerate(self._steps):
func_name = func.__name__

# get necessary code data
# -----------------------
# 1. code arguments
# 2. code argument_to_call
# 3. code lines
# 4. code scope variables
used_args, code_arg2call = set(), {}
func_args = inspect.getfullargspec(func).args
func_code, code_scope, formatter = self.merge_integrators(func)
code_scope[f'{self._name}_runner'] = self
try:
states = {k: getattr(self.ensemble, k) for k in func_args
if k not in constants.ARG_KEYWORDS and
isinstance(getattr(self.ensemble, k), types.ObjState)}
except AttributeError:
raise errors.ModelUseError(f'Model "{self._name}" does not have all the '
f'required attributes: {func_args}.')

# update functions in code scope
# 1. recursively jit the function
# 2. update the function parameters
for k, v in code_scope.items():
if profile.is_jit() and callable(v):
code_scope[k] = tools.numba_func(func=v, params=self._pars.updates)

add_args = set()
# substitute STATE item access to index
for i, arg in enumerate(func_args):
used_args.add(arg)
if len(states) == 0:
continue
if arg not in states:
continue

st = states[arg]
var2idx = st['_var2idx']
if self.ensemble._is_state_attr(arg):
if func_name.startswith('_brainpy_delayed_'):
if len(delay_keys):
dout = f'{arg}_dout'
add_args.add(dout)
code_arg2call[dout] = f'{self._name}.{arg}._delay_out'
# Function with "delayed" decorator should use ST pulled from the delay queue
for st_k in delay_keys:
p = f'{arg}\[([\'"]{st_k}[\'"])\]'
r = f"{arg}[{var2idx['_' + st_k + '_offset']} + {dout}, _obj_i]"
func_code = re.sub(r'' + p, r, func_code)
else:
if len(delay_keys):
# Function without "delayed" decorator should push
# their updated ST to the delay queue
func_code_left = '\n'.join(formatter.lefts)
func_keys = set(re.findall(r'' + arg + r'\[[\'"](\w+)[\'"]\]', func_code_left))
func_delay_keys = func_keys.intersection(delay_keys)
if len(func_delay_keys) > 0:
din = f'{arg}_din'
add_args.add(din)
code_arg2call[din] = f'{self._name}.{arg}._delay_in'
for st_k in func_delay_keys:
right = f'{arg}[{var2idx[st_k]}, _obj_i]'
left = f"{arg}[{var2idx['_' + st_k + '_offset']} + {din}, _obj_i]"
func_code += f'\n{left} = {right}'
for st_k in st._keys:
p = f'{arg}\[([\'"]{st_k}[\'"])\]'
r = f"{arg}[{var2idx[st_k]}, _obj_i]"
func_code = re.sub(r'' + p, r, func_code)
elif arg == 'pre':
# 1. implement the atomic operations for "pre"
if profile.run_on_gpu():
code_lines = func_code.split('\n')
add_cuda = False
line_no = 0
while line_no < len(code_lines):
line = code_lines[line_no]
blank_no = len(line) - len(line.lstrip())
line = line.strip()
if line.startswith('pre'):
pre_transformer = tools.find_atomic_op(line, var2idx)
if pre_transformer.left is not None:
left = pre_transformer.left
right = pre_transformer.right
code_lines[line_no] = ' ' * blank_no + f'cuda.atomic.add({left}, _pre_i, {right})'
add_cuda = True
line_no += 1
if add_cuda:
code_scope['cuda'] = cuda
func_code = '\n'.join(code_lines)
# 2. transform the key access to index access
for st_k in st._keys:
p = f'pre\[([\'"]{st_k}[\'"])\]'
r = f"pre[{var2idx[st_k]}, _pre_i]"
func_code = re.sub(r'' + p, r, func_code)
elif arg == 'post':
# 1. implement the atomic operations for "post"
if profile.run_on_gpu():
code_lines = func_code.split('\n')
add_cuda = False
line_no = 0
while line_no < len(code_lines):
line = code_lines[line_no]
blank_no = len(line) - len(line.lstrip())
line = line.strip()
if line.startswith('post'):
post_transformer = tools.find_atomic_op(line, var2idx)
if post_transformer.left is not None:
left = post_transformer.left
right = post_transformer.right
code_lines[line_no] = ' ' * blank_no + f'cuda.atomic.add({left}, _post_i, {right})'
add_cuda = True
line_no += 1
if add_cuda:
code_scope['cuda'] = cuda
func_code = '\n'.join(code_lines)
# 2. transform the key access to index access
for st_k in st._keys:
p = f'post\[([\'"]{st_k}[\'"])\]'
r = f"post[{var2idx[st_k]}, _post_i]"
func_code = re.sub(r'' + p, r, func_code)
else:
raise ValueError(f'Unknown state argument: "{arg}".')

# get formatted function arguments
# --------------------------------
# 1. For an argument in "ARG_KEYWORDS", keep it unchanged
# 2. For an argument that is an instance of ObjState, get its cuda data
# 3. For any other argument, get its cuda data
code_args = add_args
for arg in used_args:
if arg in constants.ARG_KEYWORDS:
code_arg2call[arg] = arg
else:
data = getattr(self.ensemble, arg)
if profile.run_on_cpu():
if isinstance(data, types.ObjState):
code_arg2call[arg] = f'{self._name}.{arg}["_data"]'
else:
code_arg2call[arg] = f'{self._name}.{arg}'
else:
# Both ObjState and plain array arguments are referenced through the
# runner's CUDA copy; set_gpu_data() performs the host-to-device transfer.
code_arg2call[arg] = f'{self._name}_runner.{arg}_cuda'
self.set_gpu_data(f'{arg}_cuda', data)
code_args.add(arg)

# add the for loop in the start of the main code
has_pre = 'pre' in func_args
has_post = 'post' in func_args
if profile.run_on_cpu():
code_lines = [f'for _obj_i in numba.prange({self.ensemble.num}):']
code_scope['numba'] = numba
else:
code_lines = [f'_obj_i = cuda.grid(1)',
f'if _obj_i < {self.ensemble.num}:']
code_scope['cuda'] = cuda

if has_pre:
code_args.add('pre_ids')
code_arg2call['pre_ids'] = f'{self._name}_runner.pre_ids'
code_lines.append('  _pre_i = pre_ids[_obj_i]')
self.set_data('pre_ids', getattr(self.ensemble, 'pre_ids'))
if has_post:
code_args.add('post_ids')
code_arg2call['post_ids'] = f'{self._name}_runner.post_ids'
code_lines.append('  _post_i = post_ids[_obj_i]')
self.set_data('post_ids', getattr(self.ensemble, 'post_ids'))

# substitute heterogeneous parameter "p" to "p[_obj_i]"
# ------------------------------------------------------
arg_substitute = {}
for p in self._pars.heters.keys():
if p in code_scope:
arg_substitute[p] = f'{p}[_obj_i]'
if len(arg_substitute):
func_code = tools.word_replace(func_code, arg_substitute)
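# Illustrative effect (parameter name hypothetical): if "tau" is a
# heterogeneous parameter, a user line "V += dt / tau * (-V)" is rewritten
# to "V += dt / tau[_obj_i] * (-V)", so every object indexes its own value.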

# add the main code (user defined)
# ------------------
for l in func_code.split('\n'):
code_lines.append(' ' + l)
code_lines.append('\n')
stripped_fname = tools.get_func_name(func, replace=True)
code_lines.insert(0, f'# "{stripped_fname}" step function of {self._name}')

# update code scope
# ------------------
for k in list(code_scope.keys()):
if k in self._pars.updates:
if profile.run_on_cpu():
# run on cpu :
# 1. update the parameter
# 2. remove the heterogeneous parameter
code_scope[k] = self._pars.updates[k]
if k in all_heter_pars:
all_heter_pars.remove(k)
else:
# run on gpu :
# 1. update the parameter
# 2. transform the heterogeneous parameter to function argument
if k in all_heter_pars:
code_args.add(k)
code_arg2call[k] = cuda.to_device(self._pars.updates[k])
else:
code_scope[k] = self._pars.updates[k]

# handle the "_normal_like_"
# ---------------------------
func_code = '\n'.join(code_lines)
if len(NoiseHandler.normal_pattern.findall(func_code)):
if profile.run_on_gpu(): # gpu noise
func_code = NoiseHandler.normal_pattern.sub(NoiseHandler.cuda_replace_f, func_code)
code_scope['xoroshiro128p_normal_float64'] = xoroshiro128p_normal_float64
num_block, num_thread = tools.get_cuda_size(self.ensemble.num)
code_args.add('rng_states')
code_arg2call['rng_states'] = f'{self._name}_runner.rng_states'
rng_state = create_xoroshiro128p_states(num_block * num_thread, seed=np.random.randint(100000))
setattr(self, 'rng_states', rng_state)
else: # cpu noise
func_code = NoiseHandler.normal_pattern.sub(NoiseHandler.scalar_replace_f, func_code)
code_scope['numpy'] = np
code_lines = func_code.split('\n')
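# Illustrative effect of the substitution above (the exact replacement text
# is defined by NoiseHandler): a call such as "_normal_like_(V)" in the user
# code is rewritten to a backend-specific sampler -- a NumPy normal draw on
# CPU, or xoroshiro128p_normal_float64(rng_states, ...) on CUDA.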

# code to compile
# -----------------
# 1. get the codes to compile
code_to_compile = [f'def {stripped_fname}({tools.func_call(code_args)}):']
code_to_compile += code_lines
func_code = '\n '.join(code_to_compile)
exec(compile(func_code, '', 'exec'), code_scope)
# 2. output the function codes
if not profile.is_merge_steps():
if profile.show_format_code():
utils.show_code_str(func_code.replace('def ', f'def {self._name}_'))
if profile.show_code_scope():
utils.show_code_scope(code_scope, ['__builtins__', stripped_fname])
# 3. jit the compiled function
func = code_scope[stripped_fname]
if profile.run_on_cpu():
if profile.is_jit():
func = tools.jit(func)
else:
func = cuda.jit(func)
# 4. set the function to the model
setattr(self, stripped_fname, func)

# get function call
# -----------------
# 1. get the functional arguments
arg2calls = [code_arg2call[arg] for arg in sorted(list(code_args))]
arg_code = tools.func_call(arg2calls)
if profile.run_on_cpu():
# 2. function call on cpu
func_call = f'{self._name}_runner.{stripped_fname}({arg_code})'
else:
# 3. function call on gpu
num_block, num_thread = tools.get_cuda_size(self.ensemble.num)
func_call = f'{self._name}_runner.{stripped_fname}[{num_block}, {num_thread}]({arg_code})'
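# Illustrative call strings (names and sizes hypothetical):
#   CPU: "LIF0_runner.update(LIF0.ST['_data'], _t)"
#   GPU: "LIF0_runner.update[8, 1024](LIF0_runner.ST_cuda, _t)"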

# the final result
# ------------------
results[stripped_fname] = {'scopes': code_scope,
'args': code_args,
'arg2calls': code_arg2call,
'codes': code_lines,
'call': func_call,
'num_data': self.ensemble.num}

# WARNING: a heterogeneous parameter may not appear in the main step functions
if len(all_heter_pars) > 0:
raise errors.ModelDefError(f'''
Heterogeneous parameters "{list(all_heter_pars)}" are not defined
in the main step functions, so BrainPy cannot recognize them.

This error may be caused by:
1. The heterogeneous parameter is defined in other, non-main step functions.
2. The heterogeneous parameter is defined in "integrators", but
"profile.set(merge_integrators=True)" is not called.

Several ways to correct this error:
1. Define the heterogeneous parameter in the "ST".
2. Call "profile.set(merge_integrators=True)" before the network definition.

''')

return results

def merge_codes(self, compiled_result):
codes_of_calls = [] # call the compiled functions

if profile.run_on_cpu():
if profile.is_merge_steps():
lines, code_scopes, args, arg2calls = [], dict(), set(), dict()
for item in self.get_schedule():
if item in compiled_result:
lines.extend(compiled_result[item]['codes'])
code_scopes.update(compiled_result[item]['scopes'])
args = args | compiled_result[item]['args']
arg2calls.update(compiled_result[item]['arg2calls'])

args = sorted(list(args))
arg2calls_list = [arg2calls[arg] for arg in args]
lines.insert(0, f'\n# {self._name} "merge_func"'
f'\ndef merge_func({tools.func_call(args)}):')
func_code = '\n '.join(lines)
exec(compile(func_code, '', 'exec'), code_scopes)

func = code_scopes['merge_func']
if profile.is_jit():
func = tools.jit(func)
self.merge_func = func
func_call = f'{self._name}_runner.merge_func({tools.func_call(arg2calls_list)})'
codes_of_calls.append(func_call)
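# The merged function produced above looks roughly like (illustrative):
#   def merge_func(LIF0_ST, LIF0_input):
#       # "input" step function of LIF0
#       ...
#       # "update" step function of LIF0
#       ...
# and is invoked once per time step through "func_call".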

if profile.show_format_code():
utils.show_code_str(func_code.replace('def ', f'def {self._name}_'))
if profile.show_code_scope():
utils.show_code_scope(code_scopes, ('__builtins__', 'merge_func'))

else:
for item in self.get_schedule():
if item in compiled_result:
func_call = compiled_result[item]['call']
codes_of_calls.append(func_call)

else:
if profile.is_merge_steps():
print('WARNING: GPU mode does not support merging steps.')

for item in self.get_schedule():
for compiled_key in compiled_result.keys():
if compiled_key.startswith(item):
func_call = compiled_result[compiled_key]['call']
codes_of_calls.append(func_call)
codes_of_calls.append('cuda.synchronize()')

return codes_of_calls

def get_schedule(self):
return self._schedule

def set_schedule(self, schedule):
if not isinstance(schedule, (list, tuple)):
raise errors.ModelUseError('"schedule" must be a list/tuple.')
all_func_names = ['input', 'monitor'] + self._step_names
for s in schedule:
if s not in all_func_names:
raise errors.ModelUseError(f'Unknown step function "{s}" for model "{self._name}".')
self._schedule = schedule

def set_data(self, key, data):
if profile.run_on_gpu():
if np.isscalar(data):
data_cuda = data
else:
data_cuda = cuda.to_device(data)
setattr(self, key, data_cuda)
else:
setattr(self, key, data)

def set_gpu_data(self, key, val):
if key not in self.gpu_data:
if isinstance(val, np.ndarray):
val = cuda.to_device(val)
elif isinstance(val, types.ObjState):
val = val.get_cuda_data()
setattr(self, key, val)
self.gpu_data[key] = val

def gpu_data_to_cpu(self):
for val in self.gpu_data.values():
val.to_host()


class TrajectoryRunner(Runner):
"""Runner class for trajectory.

Parameters
----------
ensemble : NeuGroup
The neuron ensemble.
target_vars : tuple, list
The targeted variables for trajectory.
fixed_vars : dict
The fixed variables.
"""

def __init__(self, ensemble, target_vars, fixed_vars=None):
# check ensemble
from brainpy.core.neurons import NeuGroup
if not isinstance(ensemble, NeuGroup):
raise errors.ModelUseError(f'{type(self).__name__} only supports instances of NeuGroup.')

# initialization
super(TrajectoryRunner, self).__init__(ensemble=ensemble)

# check targeted variables
if not isinstance(target_vars, (list, tuple)):
raise errors.ModelUseError('"target_vars" must be a list/tuple.')
for var in target_vars:
if var not in self._model.variables:
raise errors.ModelUseError(f'"{var}" in "target_vars" is not defined in model "{self._model.name}".')
self.target_vars = target_vars

# check fixed variables
if fixed_vars is None:
fixed_vars = dict()
if not isinstance(fixed_vars, dict):
raise errors.ModelUseError('"fixed_vars" must be a dict.')
self.fixed_vars = dict()
for integrator in self._model.integrators:
var_name = integrator.diff_eq.var_name
if var_name not in target_vars:
if var_name in fixed_vars:
self.fixed_vars[var_name] = fixed_vars.get(var_name)
else:
self.fixed_vars[var_name] = self._model.variables.get(var_name)
for var in fixed_vars.keys():
if var not in self.fixed_vars:
self.fixed_vars[var] = fixed_vars.get(var)
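# Illustrative outcome (variable names hypothetical): for a model with
# integrators over ("V", "w") and target_vars == ['V'], "w" is fixed to
# fixed_vars['w'] when given, otherwise to its default value in the model.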

def format_step_code(self, func_code):
"""Format code of user defined step function.

Parameters
----------
func_code : str
The user defined function.
"""
tree = ast.parse(func_code.strip())
formatter = tools.LineFormatterForTrajectory(self.fixed_vars)
formatter.visit(tree)
return formatter

+ 0
- 244
brainpy/core/synapses.py View File

@@ -1,244 +0,0 @@
# -*- coding: utf-8 -*-

import re

import numpy as np

from . import base
from . import constants
from . import neurons
from .. import connectivity
from .. import errors
from .. import profile
from .. import tools

__all__ = [
'SynType',
'SynConn',
'delayed',
]

_SYN_CONN_NO = 0


class SynType(base.ObjType):
"""Abstract Synapse Type.

It can be defined based on a collection of synapses or a single synapse model.
"""

def __init__(self, name, ST, steps, mode='vector', requires=None, hand_overs=None, ):
if mode not in [constants.SCALAR_MODE, constants.VECTOR_MODE, constants.MATRIX_MODE]:
raise errors.ModelDefError('SynType only supports "scalar", "vector" or "matrix".')

super(SynType, self).__init__(
ST=ST,
requires=requires,
steps=steps,
name=name,
mode=mode,
hand_overs=hand_overs)

# inspect delay keys
# ------------------

# delay function
delay_funcs = []
for func in self.steps:
if func.__name__.startswith('_brainpy_delayed_'):
delay_funcs.append(func)
if len(delay_funcs):
delay_func_code = '\n'.join([tools.deindent(tools.get_main_code(func)) for func in delay_funcs])
delay_func_code_left = '\n'.join(tools.format_code(delay_func_code).lefts)

# get delayed variables
_delay_keys = set()
delay_keys_in_left = set(re.findall(r'ST\[[\'"](\w+)[\'"]\]', delay_func_code_left))
if len(delay_keys_in_left) > 0:
raise errors.ModelDefError('A delayed function cannot assign values to "ST".')
delay_keys = set(re.findall(r'ST\[[\'"](\w+)[\'"]\]', delay_func_code))
if len(delay_keys) > 0:
_delay_keys.update(delay_keys)
self._delay_keys = list(_delay_keys)


class SynConn(base.Ensemble):
"""Synaptic connections.

Parameters
----------
model : SynType
The instantiated synapse type model.
pars_update : dict
Parameters to update.
pre_group : neurons.NeuGroup, neurons.NeuSubGroup
Pre-synaptic neuron group.
post_group : neurons.NeuGroup, neurons.NeuSubGroup
Post-synaptic neuron group.
conn : connectivity.Connector
Connection method to create synaptic connectivity.
num : int
The number of synapses.
delay : float
The synaptic delay time.
monitors : list, tuple
Variables to monitor.
name : str
The name of the synaptic connection.
"""

def __init__(self, model, pre_group=None, post_group=None, conn=None, delay=0.,
name=None, monitors=None, satisfies=None, pars_update=None, ):
# name
# ----
if name is None:
global _SYN_CONN_NO
name = f'SynConn{_SYN_CONN_NO}'
_SYN_CONN_NO += 1
else:
name = name

# model
# ------
if not isinstance(model, SynType):
raise errors.ModelUseError(f'{type(self).__name__} receives an instance of {SynType.__name__}, '
f'not {type(model).__name__}.')

if model.mode == 'scalar':
if pre_group is None or post_group is None:
raise errors.ModelUseError('A scalar-based synapse model must '
'provide "pre_group" and "post_group".')

# pre or post neuron group
# ------------------------
self.pre_group = pre_group
self.post_group = post_group
self.conn = None
num = 1
if pre_group is not None and post_group is not None:
# check
# ------
if not isinstance(pre_group, (neurons.NeuGroup, neurons.NeuSubGroup)):
raise errors.ModelUseError('"pre_group" must be an instance of NeuGroup/NeuSubGroup.')
if not isinstance(post_group, (neurons.NeuGroup, neurons.NeuSubGroup)):
raise errors.ModelUseError('"post_group" must be an instance of NeuGroup/NeuSubGroup.')

# pre and post synaptic state
self.pre = pre_group.ST
self.post = post_group.ST

if conn is not None:
# connections
# ------------
if isinstance(conn, connectivity.Connector):
self.conn = conn
self.conn(pre_group.indices, post_group.indices)
else:
if isinstance(conn, np.ndarray):
# check matrix dimension
if np.ndim(conn) != 2:
raise errors.ModelUseError(f'"conn" must be a 2D array, not {np.ndim(conn)}D.')
# check matrix shape
conn_shape = np.shape(conn)
if not (conn_shape[0] == pre_group.num and conn_shape[1] == post_group.num):
raise errors.ModelUseError(
f'The shape of "conn" must be ({pre_group.num}, {post_group.num})')
# get pre_ids and post_ids
pre_ids, post_ids = np.where(conn > 0)
else:
# check conn type
if not isinstance(conn, dict):
raise errors.ModelUseError(f'"conn" only support "dict", 2D ndarray, '
f'or instance of bp.connect.Connector.')
# check conn content
if not ('i' in conn and 'j' in conn):
raise errors.ModelUseError('When "conn" is a dict, it must contain the keys "i" and "j".')
# get pre_ids and post_ids
pre_ids = np.asarray(conn['i'], dtype=np.int_)
post_ids = np.asarray(conn['j'], dtype=np.int_)
self.conn = connectivity.Connector()
self.conn.pre_ids = pre_group.indices.flatten()[pre_ids]
self.conn.post_ids = post_group.indices.flatten()[post_ids]

# get synaptic structures
self.conn.set_size(num_post=post_group.size, num_pre=pre_group.size)
if model.mode == constants.SCALAR_MODE:
self.conn.set_requires(model.step_args + ['post2syn', 'pre2syn'])
else:
self.conn.set_requires(model.step_args)
for k in self.conn.requires:
setattr(self, k, getattr(self.conn, k))
self.pre_ids = self.conn.pre_ids
self.post_ids = self.conn.post_ids
num = len(self.pre_ids)

if satisfies is not None and 'num' in satisfies:
num = satisfies['num']

if not (0 < num < 2 ** 64):
raise errors.ModelUseError('Total synapse number "num" must be a valid "uint64" value.')

# initialize
# ----------
super(SynConn, self).__init__(model=model,
pars_update=pars_update,
name=name,
num=num,
monitors=monitors,
cls_type=constants.SYN_CONN_TYPE,
satisfies=satisfies)

# delay
# -------
if delay is None:
delay_len = 1
elif isinstance(delay, (int, float)):
dt = profile.get_dt()
delay_len = int(np.ceil(delay / dt))
if delay_len == 0:
delay_len = 1
else:
raise ValueError("BrainPy currently doesn't support other kinds of delay.")
self.delay_len = delay_len # delay length

# ST
# --
if self.model.mode == constants.MATRIX_MODE:
if pre_group is None:
if 'pre_size' not in satisfies:
raise errors.ModelUseError('"pre_size" must be provided in "satisfies" when "pre_group" is none.')
pre_size = satisfies['pre_size']
else:
pre_size = pre_group.size

if post_group is None:
if 'post_size' not in satisfies:
raise errors.ModelUseError('"post_size" must be provided in "satisfies" when "post_group" is none.')
post_size = satisfies['post_size']
else:
post_size = post_group.size
size = (pre_size, post_size)
else:
size = (self.num,)
self.ST = self.model.ST.make_copy(size=size,
delay=delay_len,
delay_vars=self.model._delay_keys)


def delayed(func):
"""Decorator for synapse delay.

Parameters
----------
func : callable
The step function which use delayed synapse state.

Returns
-------
func : callable
The modified step function.
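
For example (a minimal sketch; the step function and its arguments
are hypothetical):

>>> @delayed
>>> def output(ST, post):
>>>     post['input'] += ST['g']
>>> output.__name__
'_brainpy_delayed_output'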
"""
func.__name__ = f'_brainpy_delayed_{func.__name__}'
return func

+ 0
- 442
brainpy/core/types.py View File

@@ -1,442 +0,0 @@
# -*- coding: utf-8 -*-

import math
from collections import OrderedDict

import numba as nb
import numpy as np
from numba import cuda

from .. import errors
from .. import profile

__all__ = [
'TypeChecker',
'ObjState',
'NeuState',
'SynState',
'gpu_set_vector_val',
'ListConn',
'MatConn',
'Array',
'Int',
'Float',
'List',
'Dict',
]


class TypeChecker(object):
def __init__(self, help):
self.help = help

def check(self, cls):
raise NotImplementedError

@classmethod
def make_copy(cls, *args, **kwargs):
raise NotImplementedError


def gpu_set_scalar_val(data, val, idx):
i = cuda.grid(1)
if i < data.shape[1]:
data[idx, i] = val


def gpu_set_vector_val(data, val, idx):
i = cuda.grid(1)
if i < data.shape[1]:
data[idx, i] = val[i]


if cuda.is_available():
gpu_set_scalar_val = cuda.jit('(float64[:, :], float64, int64)')(gpu_set_scalar_val)
gpu_set_vector_val = cuda.jit('(float64[:, :], float64[:], int64)')(gpu_set_vector_val)
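# These kernels are launched over an explicit grid, e.g. (sizes illustrative):
#   gpu_set_vector_val[num_block, num_thread](gpu_data, new_values, row_idx)
# See ObjState.__setitem__ below for how num_block/num_thread are computed.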


class ObjState(dict, TypeChecker):
def __init__(self, *args, help='', **kwargs):
# 1. initialize TypeChecker
TypeChecker.__init__(self, help=help)

# 2. get variables
variables = OrderedDict()
for a in args:
if isinstance(a, str):
variables[a] = 0.
elif isinstance(a, (tuple, list)):
for v in a:
variables[v] = 0.
elif isinstance(a, dict):
for key, val in a.items():
if not isinstance(val, (int, float)):
raise ValueError(f'The default value setting in a dict must be int/float.')
variables[key] = val
else:
raise ValueError(f'Only str/tuple/list/dict are supported, not {type(a)}.')
for key, val in kwargs.items():
if not isinstance(val, (int, float)):
raise ValueError(f'The default value setting must be int/float.')
variables[key] = val

# 3. others
self._keys = list(variables.keys())
self._values = list(variables.values())
self._vars = variables

def check(self, cls):
if not isinstance(cls, type(self)):
raise errors.TypeMismatchError(f'Must be an instance of "{type(self)}", but got "{type(cls)}".')
for k in self._keys:
if k not in cls:
raise errors.TypeMismatchError(f'Key "{k}" is not found in "cls".')

def get_cuda_data(self):
_data_cuda = self.__getitem__('_data_cuda')
if _data_cuda is None:
_data = self.__getitem__('_data')
_data_cuda = cuda.to_device(_data)
super(ObjState, self).__setitem__('_data_cuda', _data_cuda)
return _data_cuda

def __setitem__(self, key, val):
if key in self._vars:
# get data
data = self.__getitem__('_data')
_var2idx = self.__getitem__('_var2idx')
idx = _var2idx[key]
# gpu setattr
if profile.run_on_gpu():
gpu_data = self.get_cuda_data()
if data.shape[1] <= profile._num_thread_gpu:
num_thread = data.shape[1]
num_block = 1
else:
num_thread = profile._num_thread_gpu
num_block = math.ceil(data.shape[1] / profile._num_thread_gpu)
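# e.g. (illustrative sizes) with data.shape[1] == 2500 and
# profile._num_thread_gpu == 1024: num_thread == 1024, num_block == 3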
if np.isscalar(val):
gpu_set_scalar_val[num_block, num_thread](gpu_data, val, idx)
else:
if val.shape[0] != data.shape[1]:
raise ValueError(f'Wrong value dimension {val.shape[0]} != {data.shape[1]}')
gpu_set_vector_val[num_block, num_thread](gpu_data, val, idx)
cuda.synchronize()
# cpu setattr
else:
data[idx] = val
elif key in ['_data', '_var2idx', '_idx2var']:
raise KeyError(f'"{key}" cannot be modified.')
else:
raise KeyError(f'"{key}" is not defined in {type(self).__name__}, '
f'only finds "{str(self._keys)}".')

def __str__(self):
return f'{self.__class__.__name__} ({str(self._keys)})'

def __repr__(self):
return self.__str__()


class NeuState(ObjState):
"""Neuron State Management. """

def __call__(self, size):
if isinstance(size, int):
size = (size,)
elif isinstance(size, (tuple, list)):
size = tuple(size)
else:
raise ValueError(f'Unknown size type: {type(size)}.')

data = np.zeros((len(self._vars),) + size, dtype=np.float_)
var2idx = dict()
idx2var = dict()
state = dict()
for i, (k, v) in enumerate(self._vars.items()):
state[k] = data[i]
data[i] = v
var2idx[k] = i
idx2var[i] = k
state['_data'] = data
state['_data_cuda'] = None
state['_var2idx'] = var2idx
state['_idx2var'] = idx2var

dict.__init__(self, state)

return self

def make_copy(self, size):
obj = NeuState(self._vars)
return obj(size=size)


@nb.njit([nb.types.UniTuple(nb.int64[:], 2)(nb.int64[:], nb.int64[:], nb.int64[:]),
nb.types.UniTuple(nb.int64, 2)(nb.int64, nb.int64, nb.int64)])
def update_delay_indices(delay_in, delay_out, delay_len):
_delay_in = (delay_in + 1) % delay_len
_delay_out = (delay_out + 1) % delay_len
return _delay_in, _delay_out


class SynState(ObjState):
"""Synapse State Management. """

def __init__(self, *args, help='', **kwargs):
super(SynState, self).__init__(*args, help=help, **kwargs)
self._delay_len = 1
self._delay_in = 0
self._delay_out = 0

def __call__(self, size, delay=None, delay_vars=()):
# check size
if isinstance(size, int):
size = (size,)
elif isinstance(size, (tuple, list)):
size = tuple(size)
else:
raise ValueError(f'Unknown size type: {type(size)}.')

# check delay
delay = 0 if (delay is None) or (delay < 1) else delay
assert isinstance(delay, int), '"delay" must be an int specifying the delay length.'
self._delay_len = delay
self._delay_in = delay - 1

# check delay_vars
if isinstance(delay_vars, str):
delay_vars = (delay_vars,)
elif isinstance(delay_vars, (tuple, list)):
delay_vars = tuple(delay_vars)
else:
raise ValueError(f'Unknown delay_vars type: {type(delay_vars)}.')

# initialize data
length = len(self._vars) + delay * len(delay_vars)
data = np.zeros((length,) + size, dtype=np.float_)
var2idx = dict()
idx2var = dict()
state = dict()
for i, (k, v) in enumerate(self._vars.items()):
data[i] = v
state[k] = data[i]
var2idx[k] = i
idx2var[i] = k
index_offset = len(self._vars)
for i, v in enumerate(delay_vars):
var2idx[f'_{v}_offset'] = i * delay + index_offset
state[f'_{v}_delay'] = data[i * delay + index_offset: (i + 1) * delay + index_offset]
state['_data'] = data
state['_data_cuda'] = None
state['_var2idx'] = var2idx
state['_idx2var'] = idx2var

dict.__init__(self, state)

return self

def make_copy(self, size, delay=None, delay_vars=()):
obj = SynState(self._vars)
return obj(size=size, delay=delay, delay_vars=delay_vars)

def delay_push(self, g, var):
if self._delay_len > 0:
data = self.__getitem__('_data')
offset = self.__getitem__('_var2idx')[f'_{var}_offset']
data[self._delay_in + offset] = g

def delay_pull(self, var):
if self._delay_len > 0:
data = self.__getitem__('_data')
offset = self.__getitem__('_var2idx')[f'_{var}_offset']
return data[self._delay_out + offset]
else:
data = self.__getitem__('_data')
var2idx = self.__getitem__('_var2idx')
return data[var2idx[var]]

def _update_delay_indices(self):
din, dout = update_delay_indices(self._delay_in, self._delay_out, self._delay_len)
self._delay_in = din
self._delay_out = dout
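# Illustrative index evolution for delay_len == 3, starting from
# _delay_in == 2 and _delay_out == 0 (as set in __call__):
#   after one update:  _delay_in == 0, _delay_out == 1
#   after two updates: _delay_in == 1, _delay_out == 2
# delay_push() writes row (_delay_in + offset) and delay_pull() reads row
# (_delay_out + offset), so pulled values trail pushed values in the ring.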


class ListConn(TypeChecker):
"""Synaptic connection with list type."""

def __init__(self, help=''):
super(ListConn, self).__init__(help=help)

def check(self, cls):
if profile.is_jit():
if not isinstance(cls, nb.typed.List):
raise errors.TypeMismatchError(f'In numba mode, "cls" must be an instance of {type(nb.typed.List)}, '
f'but got {type(cls)}. Hint: you can use the "ListConn.make_copy()" method.')
if not isinstance(cls[0], (nb.typed.List, np.ndarray)):
raise errors.TypeMismatchError(f'In numba mode, elements in "cls" must be instances of '
f'{type(nb.typed.List)} or ndarray, but got {type(cls[0])}. '
f'Hint: you can use the "ListConn.make_copy()" method.')
else:
if not isinstance(cls, list):
raise errors.TypeMismatchError(f'ListConn requires a list, but got {type(cls)}.')
if not isinstance(cls[0], (list, np.ndarray)):
raise errors.TypeMismatchError(f'ListConn requires the elements of the list to be list or '
f'ndarray, but got {type(cls[0])}.')

@classmethod
def make_copy(cls, conn):
assert isinstance(conn, (list, tuple)), '"conn" must be a tuple/list.'
assert isinstance(conn[0], (list, tuple)), 'Elements of "conn" must be tuple/list.'
if profile.is_jit():
a_list = nb.typed.List()
for l in conn:
a_list.append(np.uint64(l))
else:
a_list = conn
return a_list

def __str__(self):
return 'ListConn'


class MatConn(TypeChecker):
"""Synaptic connection with matrix (2d array) type."""

def __init__(self, help=''):
super(MatConn, self).__init__(help=help)

def check(self, cls):
if not (isinstance(cls, np.ndarray) and np.ndim(cls) == 2):
raise errors.TypeMismatchError(f'MatConn requires a two-dimensional ndarray.')

def __str__(self):
return 'MatConn'


class SliceConn(TypeChecker):
def __init__(self, help=''):
super(SliceConn, self).__init__(help=help)

def check(self, cls):
if not (isinstance(cls, np.ndarray) and np.shape(cls)[1] == 2):
raise errors.TypeMismatchError('SliceConn requires an ndarray of shape (N, 2).')

def __str__(self):
return 'SliceConn'


class Array(TypeChecker):
"""NumPy ndarray."""

def __init__(self, dim, help=''):
self.dim = dim
super(Array, self).__init__(help=help)

def __call__(self, size):
if isinstance(size, int):
assert self.dim == 1
else:
assert len(size) == self.dim
return np.zeros(size, dtype=np.float_)

def check(self, cls):
if not (isinstance(cls, np.ndarray) and np.ndim(cls) == self.dim):
raise errors.TypeMismatchError(f'Array requires a {self.dim}-D ndarray.')

def __str__(self):
return type(self).__name__ + f' (dim={self.dim})'


class String(TypeChecker):
def __init__(self, help=''):
super(String, self).__init__(help=help)

def check(self, cls):
if not isinstance(cls, str):
raise errors.TypeMismatchError(f'Require a string, got {type(cls)}.')

def __str__(self):
return 'StringType'


class Int(TypeChecker):
def __init__(self, help=''):
super(Int, self).__init__(help=help)

def check(self, cls):
if not isinstance(cls, int):
raise errors.TypeMismatchError(f'Require an int, got {type(cls)}.')

def __str__(self):
return 'IntType'


class Float(TypeChecker):
def __init__(self, help=''):
super(Float, self).__init__(help=help)

def check(self, cls):
if not isinstance(cls, float):
raise errors.TypeMismatchError(f'Require a float, got {type(cls)}.')

def __str__(self):
return 'FloatType'


class List(TypeChecker):
def __init__(self, item_type=None, help=''):
if item_type is None:
self.item_type = None
else:
assert isinstance(item_type, TypeChecker), 'Must be a TypeChecker.'
self.item_type = item_type

super(List, self).__init__(help=help)

def check(self, cls):
if profile.is_jit():
if not isinstance(cls, nb.typed.List):
raise errors.TypeMismatchError(f'In numba, "List" requires an instance of {type(nb.typed.List)}, '
f'but got {type(cls)}.')
else:
if not isinstance(cls, list):
raise errors.TypeMismatchError(f'"List" requires an instance of list, '
f'but got {type(cls)}.')

if self.item_type is not None:
self.item_type.check(cls[0])

def __str__(self):
return type(self).__name__ + f'(item_type={str(self.item_type)})'


class Dict(TypeChecker):
def __init__(self, key_type=String, item_type=None, help=''):
if key_type is not None:
assert isinstance(key_type, TypeChecker), 'Must be a TypeChecker.'
self.key_type = key_type
if item_type is not None:
assert isinstance(item_type, TypeChecker), 'Must be a TypeChecker.'
self.item_type = item_type
super(Dict, self).__init__(help=help)

def check(self, cls):
if profile.is_jit():
if not isinstance(cls, nb.typed.Dict):
raise errors.TypeMismatchError(f'In numba, "Dict" requires an instance of {type(nb.typed.Dict)}, '
f'but got {type(cls)}.')
else:
if not isinstance(cls, dict):
raise errors.TypeMismatchError(f'"Dict" requires an instance of dict, '
f'but got {type(cls)}.')

if self.key_type is not None:
for key in cls.keys():
self.key_type.check(key)
if self.item_type is not None:
for item in cls.values():
self.item_type.check(item)

def __str__(self):
return type(self).__name__ + f'(key_type={str(self.key_type)}, item_type={str(self.item_type)})'

+ 0
- 151
brainpy/core/utils.py View File

@@ -1,151 +0,0 @@
# -*- coding: utf-8 -*-

import inspect
from pprint import pprint

from numba.core.dispatcher import Dispatcher

from .. import backend
from .. import errors
from .. import integration
from .. import tools

__all__ = [
'show_code_str',
'show_code_scope',
'find_integrators',
'get_func_scope',
'check_slice',
]


def check_slice(start, end, length):
if start >= end:
raise errors.ModelUseError(f'Illegal start/end values for subgroup, {start}>={end}')
if start >= length:
raise errors.ModelUseError(f'Illegal start value for subgroup, {start}>={length}')
if end > length:
raise errors.ModelUseError(f'Illegal stop value for subgroup, {end}>{length}')
if start < 0:
raise errors.ModelUseError('Indices must be non-negative.')


def show_code_str(func_code):
print(func_code)
print()


def show_code_scope(code_scope, ignores=()):
scope = {}
for k, v in code_scope.items():
if k in ignores:
continue
if k in integration.CONSTANT_MAPPING:
continue
if k in integration.FUNCTION_MAPPING:
continue
scope[k] = v
pprint(scope)
print()


def find_integrators(func):
"""Find integrators in a given function.

Parameters
----------
func : callable
The function.

Returns
-------
integrators : list
A list of integrators.
"""
if not callable(func) or type(func).__name__ != 'function':
return []

integrals = []
variables = inspect.getclosurevars(func)
scope = dict(variables.nonlocals)
scope.update(variables.globals)
for val in scope.values():
if isinstance(val, integration.Integrator):
integrals.append(val)
elif callable(val):
integrals.extend(find_integrators(val))
return integrals
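# Illustrative usage (names hypothetical): given
#   @integrate
#   def int_m(m, t, V): ...
#   def update(ST, _t):
#       ST['m'] = int_m(ST['m'], _t, ST['V'])
# find_integrators(update) returns [int_m], found by recursively walking
# the closure and global variables of "update".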


def _update_scope(k, v, scope):
if type(v).__name__ in ['module', 'function']:
return
if isinstance(v, integration.Integrator):
return
if k in scope:
if v != scope[k]:
raise ValueError(f'Found scope variable "{k}" with different values: \n'
f'{k} = {v} and {k} = {scope[k]}. \n'
f'This may cause serious mistakes later. Please rename one of them.')
scope[k] = v


def get_func_scope(func, include_dispatcher=False):
"""Get function scope variables.

Parameters
----------
func : callable, Integrator
include_dispatcher

Returns
-------

"""
# get function scope
if isinstance(func, integration.Integrator):
func_name = func.py_func_name
variables = inspect.getclosurevars(func.diff_eq.func)
scope = dict(variables.nonlocals)
scope.update(variables.globals)
elif type(func).__name__ == 'function':
func_name = tools.get_func_name(func, replace=True)
variables = inspect.getclosurevars(func)
if func_name.startswith('xoroshiro128p_'):
return {}
scope = dict(variables.nonlocals)
scope.update(variables.globals)
else:
if backend.func_in_numpy_or_math(func):
return {}
elif isinstance(func, Dispatcher) and include_dispatcher:
scope = get_func_scope(func.py_func)
else:
raise ValueError(f'Unknown type: {type(func)}')

# update scope
for k, v in list(scope.items()):
# get the scope of the function item
if callable(v):
if isinstance(v, Dispatcher):
if include_dispatcher:
for k2, v2 in get_func_scope(v.py_func).items():
try:
_update_scope(k2, v2, scope)
except ValueError:
raise ValueError(f'Definition error in function "{func_name}".')
else:
for k2, v2 in get_func_scope(v).items():
try:
_update_scope(k2, v2, scope)
except ValueError:
raise ValueError(f'Definition error in function "{func_name}".')

for k in list(scope.keys()):
v = scope[k]
if type(v).__name__ in ['module', 'function']:
scope.pop(k)
if isinstance(v, integration.Integrator):
scope.pop(k)

return scope

+ 4
- 7
brainpy/errors.py View File

@@ -11,21 +11,18 @@ class ModelUseError(Exception):
pass


-class TypeMismatchError(Exception):
+class DiffEqError(Exception):
pass


-class IntegratorError(Exception):
+class CodeError(Exception):
pass


-class DiffEquationError(Exception):
+class AnalyzerError(Exception):
pass


-class CodeError(Exception):
+class PackageMissingError(Exception):
pass


-class AnalyzerError(Exception):
-pass

+ 4
- 268
brainpy/inputs.py View File

@@ -1,22 +1,13 @@
# -*- coding: utf-8 -*-

import numpy as np
-from numba.cuda import random

-from . import profile
-from . import tools
-from .core import NeuGroup
-from .core import NeuType
-from .core.types import NeuState
-from .errors import ModelUseError
+from brainpy import backend

__all__ = [
'constant_current',
'spike_current',
'ramp_current',
-'PoissonInput',
-'SpikeTimeInput',
-'FreqInput',
]


@@ -43,7 +34,7 @@ def constant_current(Iext, dt=None):
current_and_duration : tuple
(The formatted current, total duration)
"""
- dt = profile.get_dt() if dt is None else dt
+ dt = backend.get_dt() if dt is None else dt

# get input current dimension, shape, and duration
I_duration = 0.
@@ -98,7 +89,7 @@ def spike_current(points, lengths, sizes, duration, dt=None):
current_and_duration : tuple
(The formatted current, total duration)
"""
- dt = profile.get_dt() if dt is None else dt
+ dt = backend.get_dt() if dt is None else dt
assert isinstance(points, (list, tuple))
if isinstance(lengths, (float, int)):
lengths = [lengths] * len(points)
@@ -135,7 +126,7 @@ def ramp_current(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
current_and_duration : tuple
(The formatted current, total duration)
"""
- dt = profile.get_dt() if dt is None else dt
+ dt = backend.get_dt() if dt is None else dt
t_end = duration if t_end is None else t_end

current = np.zeros(int(np.ceil(duration / dt)))
@@ -143,258 +134,3 @@ def ramp_current(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
p2 = int(np.ceil(t_end / dt))
current[p1: p2] = np.linspace(c_start, c_end, p2 - p1)
return current


class PoissonInput(NeuGroup):
"""The Poisson input neuron group.

Note: ``PoissonInput`` does not work for high firing rates, because more
than one spike may fall into a single time step (``dt``). However, you can
split a high-rate input across several neurons with lower rates. For
example, use ``PoissonInput(10, 100)`` instead of ``PoissonInput(1, 1000)``.

Parameters
----------
geometry : int, tuple, list
The neuron group geometry.
freqs : float, int, np.ndarray
The spike rates.
monitors : list, tuple
The targets for monitoring.
name : str
The neuron group name.
"""

def __init__(self, geometry, freqs, monitors=None, name=None):
dt = profile.get_dt() / 1000.

# firing rate
if isinstance(freqs, np.ndarray):
freqs = freqs.flatten()
if not np.all(freqs <= 1000. / profile.get_dt()):
print(f'WARNING: The maximum supported frequency at dt={profile.get_dt()} ms '
f'is {1000. / profile.get_dt()} Hz, but the given "freqs" setting '
f'exceeds that.')

# neuron model on CPU
# -------------------
if profile.run_on_cpu():
def update(ST):
ST['spike'] = np.random.random(ST['spike'].shape) < freqs * dt

model = NeuType(name='poisson_input', ST=NeuState('spike'), steps=update, mode='vector')

# neuron model on GPU
# -------------------
else:
def update(ST, rng_states, _obj_i):
ST['spike'] = random.xoroshiro128p_uniform_float64(rng_states, _obj_i) < freqs * dt

model = NeuType(name='poisson_input', ST=NeuState('spike'), steps=update, mode='scalar')

# initialize neuron group
# -----------------------
super(PoissonInput, self).__init__(model=model, geometry=geometry, monitors=monitors, name=name)

# will automatically handle
# the heterogeneous problem
# -------------------------
self.pars['freqs'] = freqs

# rng states
# ----------
if profile.run_on_gpu():
num_block, num_thread = tools.get_cuda_size(self.num)
self.rng_states = random.create_xoroshiro128p_states(
num_block * num_thread, seed=np.random.randint(100000))


class SpikeTimeInput(NeuGroup):
"""The input neuron group characterized by spikes emitting at given times.

>>> # Get 2 neurons, firing spikes at 10 ms and 20 ms.
>>> SpikeTimeInput(2, times=[10, 20])
>>> # or
>>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
>>> SpikeTimeInput(2, times=[10, 20], indices=[0, 0])
>>> # or
>>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
>>> SpikeTimeInput(2, times=[10, 20, 30], indices=[0, 1, 0])
>>> # or
>>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
>>> # at 30 ms, neuron 1 fires.
>>> SpikeTimeInput(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

Parameters
----------
geometry : int, tuple, list
The neuron group geometry.
indices : int, list, tuple
The neuron indices at each time point to emit spikes.
times : list, np.ndarray
The time points which generate the spikes.
monitors : list, tuple
The targets for monitoring.
name : str
The group name.
"""

def __init__(self, geometry, times, indices=None, monitors=None, name=None, need_sort=True):
# number of neurons
# -----------------
if isinstance(geometry, (int, float)):
num = int(geometry)
elif isinstance(geometry, (tuple, list)):
num = int(np.prod(geometry))
else:
raise ModelUseError(f'"geometry" must be a int, or a tuple/list of int, '
f'but we got {type(geometry)}.')

# indices is not provided
# -----------------------
if indices is None:
# data about times
times = np.ascontiguousarray(times, dtype=np.float_)
if need_sort: times = np.sort(times)
num_times = len(times)

# model on CPU
if profile.run_on_cpu():
def update(ST, _t, idx):
in_idx = idx[0]
if (in_idx < num_times) and (_t >= times[in_idx]):
ST['spike'] = 1.
idx += 1
else:
ST['spike'] = 0.

model = NeuType(name='time_input', ST=NeuState('spike'),
steps=update, mode='vector',
hand_overs={'idx': np.array([0])})

else:
def update(ST, _t, idxs, _obj_i):
in_idx = idxs[_obj_i]
if (in_idx < num_times) and (_t >= times[in_idx]):
ST['spike'] = 1.
idxs[_obj_i] += 1
else:
ST['spike'] = 0.

model = NeuType(name='time_input', ST=NeuState('spike'),
steps=update, mode='scalar',
hand_overs={'idxs': np.zeros(num, dtype=np.int_)})

# indices and times are provided
# ------------------------------

else:
if len(indices) != len(times):
raise ModelUseError(f'The length of "indices" and "times" must be the same. '
f'However, we got {len(indices)} != {len(times)}.')

if profile.run_on_cpu():

# data about times and indices
times = np.ascontiguousarray(times, dtype=np.float_)
indices = np.ascontiguousarray(indices, dtype=np.int_)
num_times = len(times)
if need_sort:
sort_idx = np.argsort(times)
indices = indices[sort_idx]

# update logic
def update(ST, _t, idx):
ST['spike'] = 0.
while idx[0] < num_times and _t >= times[idx[0]]:
ST['spike'][indices[idx[0]]] = 1.
idx += 1

model = NeuType(name='time_input', ST=NeuState('spike'),
steps=update, mode='vector',
hand_overs={'idx': np.array([0])})

else:
raise NotImplementedError

# neuron group
super(SpikeTimeInput, self).__init__(model=model,
geometry=geometry,
monitors=monitors,
name=name)


class FreqInput(NeuGroup):
"""The input neuron group characterized by frequency.

For examples:

>>> # Get 2 neurons, with 10 Hz firing rate.
>>> FreqInput(2, freq=10.)
>>> # Get 4 neurons, with 20 Hz firing rate. The neurons
>>> # start firing at [10, 30] ms randomly.
>>> FreqInput(4, freq=20., start_time=np.random.randint(10, 30, (4,)))

Parameters
----------
geometry : int, list, tuple
The geometry of neuron group.
freqs : int, float, np.ndarray
The output spike frequency.
start_time : float
The time of the first spike.
monitors : list, tuple
The targets for monitoring.
name : str
The name of the neuron group.
"""

def __init__(self, geometry, freqs, start_time=0., monitors=None, name=None):
if not np.all(freqs <= 1000. / profile.get_dt()):
print(f'WARNING: The maximum supported frequency at dt={profile.get_dt()} ms '
f'is {1000. / profile.get_dt()} Hz, but the given "freqs" setting '
f'exceeds that.')

state = NeuState({'spike': 0., 't_next_spike': 0., 't_last_spike': -1e7})

if profile.is_jit():
def update_state(ST, _t_):
if _t_ >= ST['t_next_spike']:
ST['spike'] = 1.
ST['t_last_spike'] = _t_
ST['t_next_spike'] += 1000. / freqs
else:
ST['spike'] = 0.

model = NeuType(name='poisson_input',
ST=state,
steps=update_state,
mode='scalar')

else:
if np.size(freqs) == 1:
def update_state(ST, _t_):
should_spike = _t_ >= ST['t_next_spike']
ST['spike'] = should_spike
spike_ids = np.where(should_spike)[0]
ST['t_last_spike'][spike_ids] = _t_
ST['t_next_spike'][spike_ids] += 1000. / freqs

else:
def update_state(ST, _t_):
should_spike = _t_ >= ST['t_next_spike']
ST['spike'] = should_spike
spike_ids = np.where(should_spike)[0]
ST['t_last_spike'][spike_ids] = _t_
ST['t_next_spike'][spike_ids] += 1000. / freqs[spike_ids]

model = NeuType(name='freq_input',
ST=state,
steps=update_state,
mode='vector')

# neuron group
super(FreqInput, self).__init__(model=model, geometry=geometry, monitors=monitors, name=name)

self.ST['t_next_spike'] = start_time
self.pars['freqs'] = freqs

+ 0
- 74
brainpy/integration/__init__.py View File

@@ -1,74 +0,0 @@
# -*- coding: utf-8 -*-

from . import diff_equation
from . import integrator
from . import utils
from .diff_equation import *
from .integrator import *
from .utils import *
from .. import profile

_SUPPORT_METHODS = [
'euler',
'midpoint',
'heun',
'rk2',
'rk3',
'rk4',
'rk4_alternative',
'exponential',
'milstein',
'milstein_ito',
'milstein_stra',
]


def integrate(func=None, method=None):
"""Generate the one-step integrator function for differential equations.

Using this method, users only need to define the right-hand side of the equation.
For example, for the `m` channel in the Hodgkin–Huxley neuron model

.. math::

\\alpha = {0.1 * (V + 40) \\over 1 - \\exp(-(V + 40) / 10)}

\\beta = 4.0 * \\exp(-(V + 65) / 18)

{dm \\over dt} = \\alpha * (1 - m) - \\beta * m

Using ``BrainPy``, this ODE function can be written as

>>> import numpy as np
>>> from brainpy import integrate
>>>
>>> @integrate(method='rk4')
>>> def int_m(m, t, V):
>>> alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
>>> beta = 4.0 * np.exp(-(V + 65) / 18)
>>> return alpha * (1 - m) - beta * m

Parameters
----------
func : callable
The function on the right-hand side of the differential equation.
If a stochastic equation (SDE) is defined, then `func` is the drift coefficient
(the deterministic part) of the SDE.
method : None, str, callable
The method of numerical integrator.

Returns
-------
integrator : Integrator
If `func` is provided, then the one-step numerical integrator will be returned;
if not, the decorator wrapper will be returned.
"""

method = method if method is not None else profile.get_numerical_method()
_integrator_ = get_integrator(method)

if func is None:
return lambda f: _integrator_(DiffEquation(func=f))

else:
return _integrator_(DiffEquation(func=func))

+ 0
- 24
brainpy/integration/constants.py View File

@@ -1,24 +0,0 @@
# -*- coding: utf-8 -*-


CONSTANT_NOISE = 'CONSTANT'
FUNCTIONAL_NOISE = 'FUNCTIONAL'

ODE_TYPE = 'ODE'
SDE_TYPE = 'SDE'

DIFF_EQUATION = 'diff_equation'
SUB_EXPRESSION = 'sub_expression'
RETURN_TYPES = [
# return type # multi_return # DF type
'x', # False # ODE
'x,0', # False # ODE [x]
'(x,),', # False # ODE
'(x,),...', # True # ODE
'(x,0),', # False # ODE [x]
'(x,0),...', # True # ODE [x]
'x,x', # False # SDE
'(x,x),', # False # SDE
'(x,x),...', # True # SDE
]


+ 0
- 332
brainpy/integration/diff_equation.py View File

@@ -1,332 +0,0 @@
# -*- coding: utf-8 -*-

import inspect
from collections import Counter

import sympy

from . import constants
from . import utils
from .. import errors
from .. import profile
from .. import tools

__all__ = [
'Expression',
'DiffEquation',
]


class Expression(object):
def __init__(self, var, code):
self.var_name = var
self.code = code.strip()
self._substituted_code = None

@property
def identifiers(self):
return tools.get_identifiers(self.code)

def __str__(self):
return f'{self.var_name} = {self.code}'

def __repr__(self):
return self.__str__()

def __eq__(self, other):
if not isinstance(other, Expression):
return NotImplemented
if self.code != other.code:
return False
if self.var_name != other.var_name:
return False
return True

def __ne__(self, other):
return not self.__eq__(other)

def get_code(self, subs=True):
if subs:
if self._substituted_code is None:
return self.code
else:
return self._substituted_code
else:
return self.code


class DiffEquation(object):
"""Differential Equation.

A differential equation is defined as the standard form:

dx/dt = f(x) + g(x) dW

Parameters
----------
func : callable
The user defined differential equation.
"""

def __init__(self, func):
# check
if func is None:
raise errors.DiffEquationError('"func" cannot be None.')
if not (callable(func) and type(func).__name__ == 'function'):
raise errors.DiffEquationError('"func" must be a function.')

# function
self.func = func

# function string
self.code = tools.deindent(tools.get_main_code(func))
if 'return' not in self.code:
raise errors.DiffEquationError(f'"func" function must return something, '
f'but found no return.\n{self.code}')

# function arguments
self.func_args = inspect.getfullargspec(func).args

# function name
if tools.is_lambda_function(func):
self.func_name = f'_integral_{self.func_args[0]}_'
else:
self.func_name = func.__name__

# function scope
scope = inspect.getclosurevars(func)
self.func_scope = dict(scope.nonlocals)
self.func_scope.update(scope.globals)

# differential variable name and time name
self.var_name = self.func_args[0]
self.t_name = self.func_args[1]

# analyse function code
res = utils.analyse_diff_eq(self.code)
self.expressions = [Expression(v, expr) for v, expr in zip(res.variables, res.expressions)]
self.returns = res.returns
self.return_type = res.return_type
self.f_expr = None
self.g_expr = None
if res.f_expr is not None:
self.f_expr = Expression(res.f_expr[0], res.f_expr[1])
if res.g_expr is not None:
self.g_expr = Expression(res.g_expr[0], res.g_expr[1])
for k, num in Counter(res.variables).items():
if num > 1:
raise errors.DiffEquationError(
f'Found "{k}" {num} times. Please assign each expression '
f'in differential function with a unique name. ')

# analyse noise type
self.g_type = constants.CONSTANT_NOISE
self.g_value = None
if self.g_expr is not None:
self._substitute(self.g_expr, self.expressions)
g_code = self.g_expr.get_code(subs=True)
for idf in tools.get_identifiers(g_code):
if idf not in self.func_scope:
self.g_type = constants.FUNCTIONAL_NOISE
break
else:
self.g_value = eval(g_code, self.func_scope)

def _substitute(self, final_exp, expressions, substitute_vars=None):
"""Substitute expressions to get the final single expression

Parameters
----------
final_exp : Expression
The final expression.
expressions : list, tuple
The list/tuple of expressions.
"""
if substitute_vars is None:
return
if final_exp is None:
return
assert substitute_vars == 'all' or \
substitute_vars == self.var_name or \
isinstance(substitute_vars, (tuple, list))

# Goal: Substitute dependent variables into the expression
# Hint: this step doesn't require the left-hand variables to be unique
dependencies = {}
for expr in expressions:
substitutions = {}
for dep_var, dep_expr in dependencies.items():
if dep_var in expr.identifiers:
code = dep_expr.get_code(subs=True)
substitutions[sympy.Symbol(dep_var, real=True)] = utils.str2sympy(code).expr
if len(substitutions):
new_sympy_expr = utils.str2sympy(expr.code).expr.xreplace(substitutions)
new_str_expr = utils.sympy2str(new_sympy_expr)
expr._substituted_code = new_str_expr
dependencies[expr.var_name] = expr
else:
if substitute_vars == 'all':
dependencies[expr.var_name] = expr
elif substitute_vars == self.var_name:
if self.var_name in expr.identifiers:
dependencies[expr.var_name] = expr
else:
ids = expr.identifiers
for var in substitute_vars:
if var in ids:
dependencies[expr.var_name] = expr
break

# Goal: get the final differential equation
# Hint: this step requires the expression variables to be unique
substitutions = {}
for dep_var, dep_expr in dependencies.items():
code = dep_expr.get_code(subs=True)
substitutions[sympy.Symbol(dep_var, real=True)] = utils.str2sympy(code).expr
if len(substitutions):
new_sympy_expr = utils.str2sympy(final_exp.code).expr.xreplace(substitutions)
new_str_expr = utils.sympy2str(new_sympy_expr)
final_exp._substituted_code = new_str_expr

def get_f_expressions(self, substitute_vars=None):
if self.f_expr is None:
return []
self._substitute(self.f_expr, self.expressions, substitute_vars=substitute_vars)

return_expressions = []
# the derivative expression
dif_eq_code = self.f_expr.get_code(subs=True)
return_expressions.append(Expression(f'_df{self.var_name}_dt', dif_eq_code))
# needed variables
need_vars = tools.get_identifiers(dif_eq_code)
need_vars |= tools.get_identifiers(', '.join(self.returns))
# get the total return expressions
for expr in self.expressions[::-1]:
if expr.var_name in need_vars:
if not profile._substitute_equation or expr._substituted_code is None:
code = expr.code
else:
code = expr._substituted_code
return_expressions.append(Expression(expr.var_name, code))
need_vars |= tools.get_identifiers(code)
return return_expressions[::-1]
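# Illustrative result for "dm/dt = alpha * (1 - m) - beta * m" with
# sub-expressions "alpha = ..." and "beta = ..." (cf. the HH channel in
# the "integrate" docstring): the returned list is roughly
#   [Expression('alpha', ...), Expression('beta', ...),
#    Expression('_dfm_dt', 'alpha * (1 - m) - beta * m')]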

def get_g_expressions(self):
if self.is_functional_noise:
return_expressions = []
# the derivative expression
eq_code = self.g_expr.get_code(subs=True)
return_expressions.append(Expression(f'_dg{self.var_name}_dt', eq_code))
# needed variables
need_vars = tools.get_identifiers(eq_code)
# get the total return expressions
for expr in self.expressions[::-1]:
if expr.var_name in need_vars:
if not profile._substitute_equation or expr._substituted_code is None:
code = expr.code
else:
code = expr._substituted_code
return_expressions.append(Expression(expr.var_name, code))
need_vars |= tools.get_identifiers(code)
return return_expressions[::-1]
else:
return [Expression(f'_dg{self.var_name}_dt', self.g_expr.get_code(subs=True))]

def _replace_expressions(self, expressions, name, y_sub, t_sub=None):
"""Replace expressions of df part.

Parameters
----------
expressions : list, tuple
The list/tuple of expressions.
name : str
The name of the new expression.
y_sub : str
The new name of the variable "y".
t_sub : str, optional
The new name of the variable "t".

Returns
-------
list_of_expr : list
A list of expressions.
"""
return_expressions = []

# replacements
replacement = {self.var_name: y_sub}
if t_sub is not None:
replacement[self.t_name] = t_sub

# replace variables in expressions
for expr in expressions:
replace = False
identifiers = expr.identifiers
for repl_var in replacement.keys():
if repl_var in identifiers:
replace = True
break
if replace:
code = tools.word_replace(expr.code, replacement)
new_expr = Expression(f"{expr.var_name}_{name}", code)
return_expressions.append(new_expr)
replacement[expr.var_name] = new_expr.var_name
return return_expressions

def replace_f_expressions(self, name, y_sub, t_sub=None):
"""Replace expressions of df part.

Parameters
----------
name : str
The name of the new expression.
y_sub : str
The new name of the variable "y".
t_sub : str, optional
The new name of the variable "t".

Returns
-------
list_of_expr : list
A list of expressions.
"""
return self._replace_expressions(self.get_f_expressions(),
name=name, y_sub=y_sub, t_sub=t_sub)

def replace_g_expressions(self, name, y_sub, t_sub=None):
if self.is_functional_noise:
return self._replace_expressions(self.get_g_expressions(),
name=name, y_sub=y_sub, t_sub=t_sub)
else:
return []

@property
def is_multi_return(self):
return len(self.returns) > 0

@property
def is_stochastic(self):
if self.g_expr is not None:
try:
if eval(self.g_expr.code, self.func_scope) == 0.:
return False
except Exception:
pass
return True
else:
return False

@property
def is_functional_noise(self):
return self.g_type == constants.FUNCTIONAL_NOISE

@property
def stochastic_type(self):
if not self.is_stochastic:
return None
else:
pass

@property
def expr_names(self):
return [expr.var_name for expr in self.expressions]

+ 0
- 1128
brainpy/integration/integrator.py View File

@@ -1,1128 +0,0 @@
# -*- coding: utf-8 -*-

import numpy as np
import sympy

from .diff_equation import DiffEquation
from .utils import get_mapping_scope
from .utils import str2sympy
from .utils import sympy2str
from .. import backend
from .. import profile
from .. import tools
from ..errors import IntegratorError

__all__ = [
'get_integrator',
'Integrator',
'Euler',
'Heun',
'MidPoint',
'RK2',
'RK3',
'RK4',
'RK4Alternative',
'ExponentialEuler',
'MilsteinIto',
'MilsteinStra',
]


def get_integrator(method):
method = method.lower()

if method == 'euler':
return Euler
elif method == 'midpoint':
return MidPoint
elif method == 'heun':
return Heun
elif method == 'rk2':
return RK2
elif method == 'rk3':
return RK3
elif method == 'rk4':
return RK4
elif method == 'rk4_alternative':
return RK4Alternative
elif method == 'exponential':
return ExponentialEuler
elif method == 'milstein':
return MilsteinIto
elif method == 'milstein_ito':
return MilsteinIto
elif method == 'milstein_stra':
return MilsteinStra
else:
raise ValueError(f'Unknown method: {method}.')


class Integrator(object):
def __init__(self, diff_eq):
if not isinstance(diff_eq, DiffEquation):
if diff_eq.__class__.__name__ != 'function':
raise IntegratorError('"diff_eq" must be a function or an instance of DiffEquation .')
else:
diff_eq = DiffEquation(func=diff_eq)
self.diff_eq = diff_eq
self._update_code = None
self._update_func = None

def __call__(self, y0, t, *args):
return self._update_func(y0, t, *args)

def _compile(self):
# function arguments
func_args = ', '.join([f'_{self.py_func_name}_{arg}' for arg in self.diff_eq.func_args])

# function codes
func_code = f'def {self.py_func_name}({func_args}): \n'
func_code += tools.indent(self._update_code + '\n' + f'return _{self.py_func_name}_res')
func_code = tools.NoiseHandler.normal_pattern.sub(
tools.NoiseHandler.vector_replace_f, func_code)

# function scope
code_scopes = {'numpy': np}
for k_, v_ in self.code_scope.items():
if profile.is_jit() and callable(v_):
v_ = tools.numba_func(v_)
code_scopes[k_] = v_
code_scopes.update(get_mapping_scope())
code_scopes['_normal_like_'] = backend.normal_like

# function compilation
exec(compile(func_code, '', 'exec'), code_scopes)
func = code_scopes[self.py_func_name]
if profile.is_jit():
func = tools.jit(func)
self._update_func = func

@staticmethod
def get_integral_step(diff_eq, *args):
raise NotImplementedError

@property
def py_func_name(self):
return self.diff_eq.func_name

@property
def update_code(self):
return self._update_code

@property
def update_func(self):
return self._update_func

@property
def code_scope(self):
scope = self.diff_eq.func_scope
if profile.run_on_cpu():
scope['_normal_like_'] = backend.normal_like
return scope


class Euler(Integrator):
"""Forward Euler method. Also named as ``explicit_Euler``.

The simplest way for solving ordinary differential equations is "the
Euler method" by Press et al. (1992) [1]_ :

.. math::

y_{n+1} = y_n + f(y_n, t_n) \\Delta t

This formula advances the solution from :math:`t_n` to :math:`t_{n+1} = t_n + h`.
Note that the method increments a solution through an interval :math:`h`
while using derivative information from only the beginning of the interval.
As a result, the step's error is :math:`O(h^2)`.

For SDE equations, this approximation is a continuous-time stochastic process
that satisfies the iterative scheme [1]_.

.. math::

Y_{n+1} = Y_n + f(Y_n)h_n + g(Y_n)\\Delta W_n

where :math:`n=0,1, \\cdots , N-1`, :math:`Y_0=x_0`, :math:`Y_n = Y(t_n)`,
:math:`h_n = t_{n+1} - t_n` is the step size,
:math:`\\Delta W_n = [W(t_{n+1}) - W(t_n)] \\sim N(0, h_n)=\\sqrt{h}N(0, 1)`
with :math:`W(t_0) = 0`.

For simplicity, we rewrite the above equation into

.. math::

Y_{n+1} = Y_n + f_n h + g_n \\Delta W_n

As the order of convergence for the Euler-Maruyama method is low (strong order of
convergence 0.5, weak order of convergence 1), the numerical results are inaccurate
unless a small step size is used. By adding one more term from the stochastic
Taylor expansion, one obtains a 1.0 strong order of convergence scheme known
as *Milstein scheme* [2]_.

Parameters
----------
diff_eq : DiffEquation, callable
The differential equation.

Returns
-------
func : callable
The one-step numerical integrator function.

References
----------
.. [1] W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling,
W. T. Numerical Recipes in FORTRAN: The Art of Scientific
Computing, 2nd ed. Cambridge, England: Cambridge University
Press, p. 710, 1992.
.. [2] U. Picchini, Sde toolbox: Simulation and estimation of stochastic
differential equations with matlab.
"""

def __init__(self, diff_eq):
super(Euler, self).__init__(diff_eq)
self._update_code = self.get_integral_step(diff_eq)
self._compile()

@staticmethod
def get_integral_step(diff_eq, *args):
dt = profile.get_dt()
var_name = diff_eq.var_name
func_name = diff_eq.func_name
var = sympy.Symbol(var_name, real=True)

# get code lines of df part
f_expressions = diff_eq.get_f_expressions()
code_lines = [str(expr) for expr in f_expressions]
dfdt = sympy.Symbol(f'_df{var_name}_dt')

# get code lines of dg part
if diff_eq.is_stochastic:
noise = f'_normal_like_({var_name})'
code_lines.append(f'_{var_name}_dW = {noise}')
code_lines.extend([str(expr) for expr in diff_eq.get_g_expressions()])
dgdt = sympy.Symbol(f'_{var_name}_dW') * sympy.Symbol(f'_dg{var_name}_dt')
else:
dgdt = 0

# update expression
update = var + dfdt * dt + sympy.sqrt(dt) * dgdt
code_lines.append(f'{var_name} = {sympy2str(update)}')

# multiple returns
return_expr = ', '.join([var_name] + diff_eq.returns)
code_lines.append(f'_{func_name}_res = {return_expr}')

# final
code = '\n'.join(code_lines)
subs_dict = {arg: f'_{diff_eq.func_name}_{arg}' for arg in
diff_eq.func_args + diff_eq.expr_names}
code = tools.word_replace(code, subs_dict)
return code
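

# --- Illustrative sketch (not part of this module) ---
# For a single variable, the code generated above is equivalent to the
# following hand-written Euler / Euler-Maruyama step; ``f``, ``g`` and
# ``dt`` are hypothetical stand-ins for the parsed drift, diffusion and
# step size:
#
# import numpy as np
#
# def euler_maruyama_step(y, t, f, g, dt):
#     dW = np.random.normal(0., 1., np.shape(y))  # standard normal increment
#     return y + f(y, t) * dt + np.sqrt(dt) * g(y, t) * dW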


class RK2(Integrator):
"""Parametric second-order Runge-Kutta (RK2). Also named as ``RK2``.

It is given in parametric form by [3]_ .

.. math::

k_1 &= f(y_n, t_n) \\\\
k_2 &= f(y_n + \\beta \\Delta t k_1, t_n + \\beta \\Delta t) \\\\
y_{n+1} &= y_n + \\Delta t [(1-\\frac{1}{2\\beta})k_1+\\frac{1}{2\\beta}k_2]

Parameters
----------
diff_eq : DiffEquation
The differential equation.
beta : float
Popular choices for 'beta':
1/2 : explicit midpoint method
2/3 : Ralston's method
1 : Heun's method, also known as the explicit trapezoid rule

Returns
-------
func : callable
The one-step numerical integrator function.

References
----------
.. [3] https://lpsa.swarthmore.edu/NumInt/NumIntSecond.html

See Also
--------
Heun, MidPoint
"""

def __init__(self, diff_eq, beta=2 / 3):
super(RK2, self).__init__(diff_eq)
self.beta = beta
self._update_code = self.get_integral_step(diff_eq, beta)
self._compile()

@staticmethod
def get_integral_step(diff_eq, beta=2 / 3):
dt = profile.get_dt()
t_name = diff_eq.t_name
var_name = diff_eq.var_name
func_name = diff_eq.func_name
var = sympy.Symbol(var_name, real=True)

# get code lines of k1 df part
k1_expressions = diff_eq.get_f_expressions(substitute_vars=None)
code_lines = [str(expr) for expr in k1_expressions[:-1]]
code_lines.append(f'_df{var_name}_dt_k1 = {k1_expressions[-1].code}')

# k1 -> k2 increment
y_1_to_2 = f'_{func_name}_{var_name}_k1_to_k2'
t_1_to_2 = f'_{func_name}_t_k1_to_k2'
code_lines.append(f'{y_1_to_2} = {var_name} + {beta * dt} * _df{var_name}_dt_k1')
code_lines.append(f'{t_1_to_2} = {t_name} + {beta * dt}')

# get code lines of k2 df part
k2_expressions = diff_eq.replace_f_expressions('k2', y_sub=y_1_to_2, t_sub=t_1_to_2)
if len(k2_expressions):
code_lines.extend([str(expr) for expr in k2_expressions[:-1]])
code_lines.append(f'_df{var_name}_dt_k2 = {k2_expressions[-1].code}')

# final dt part
dfdt = sympy.Symbol(f'_df{var_name}_dt')
if len(k2_expressions):
coefficient2 = 1 / (2 * beta)
coefficient1 = 1 - coefficient2
code_lines.append(
f'{dfdt.name} = {coefficient1} * _df{var_name}_dt_k1 + {coefficient2} * _df{var_name}_dt_k2')
else:
code_lines.append(f'{dfdt.name} = _df{var_name}_dt_k1')

# get code lines of dg part
dgdt = 0
if diff_eq.is_stochastic:
if not np.all(diff_eq.g_value == 0.):
raise NotImplementedError('RK2 currently doesn\'t support SDE.')

# update expression
update = var + dfdt * dt + sympy.sqrt(dt) * dgdt
code_lines.append(f'{var_name} = {sympy2str(update)}')

# multiple returns
return_expr = ', '.join([var_name] + diff_eq.returns)
code_lines.append(f'_{func_name}_res = {return_expr}')

# final
code = '\n'.join(code_lines)
subs_dict = {arg: f'_{diff_eq.func_name}_{arg}' for arg in
diff_eq.func_args + diff_eq.expr_names}
code = tools.word_replace(code, subs_dict)
return code
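

# --- Illustrative sketch (not part of this module) ---
# For a plain ODE, the generated update is equivalent to this hand-written
# parametric RK2 step; beta=0.5 gives the explicit midpoint method,
# beta=2/3 Ralston's method, and beta=1.0 Heun's method:
#
# def rk2_step(y, t, f, dt, beta=2 / 3):
#     k1 = f(y, t)
#     k2 = f(y + beta * dt * k1, t + beta * dt)
#     return y + dt * ((1 - 1 / (2 * beta)) * k1 + 1 / (2 * beta) * k2)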


class Heun(Integrator):
"""Two-stage method for numerical integrator.

For ODE, please see "RK2".

For stochastic Stratonovich integral, the Heun algorithm is given by,
according to paper [4]_ [5]_.

.. math::
Y_{n+1} &= Y_n + f_n h + {1 \\over 2}[g_n + g(\\overline{Y}_n)] \\Delta W_n

\\overline{Y}_n &= Y_n + g_n \\Delta W_n


Or, it is written as

.. math::

Y_1 &= y_n + f(y_n)h + g_n \\Delta W_n

y_{n+1} &= y_n + {1 \\over 2}[f(y_n) + f(Y_1)]h + {1 \\over 2} [g(y_n) + g(Y_1)] \\Delta W_n

Parameters
----------
diff_eq : DiffEquation
The differential equation.

Returns
-------
func : callable
The one-step numerical integrator function.

References
----------
.. [4] H. Gilsing and T. Shardlow, SDELab: A package for solving stochastic differential
equations in MATLAB, Journal of Computational and Applied Mathematics 205 (2007),
no. 2, 1002-1018.
.. [5] P. E. Kloeden, E. Platen, and H. Schurz, Numerical solution of SDE through computer
experiments, Springer, 1994.

See Also
--------
RK2, MidPoint, MilsteinStra
"""

def __init__(self, diff_eq):
super(Heun, self).__init__(diff_eq)
self._update_code = self.get_integral_step(diff_eq)
self._compile()

@staticmethod
def get_integral_step(diff_eq, *args):
if diff_eq.is_stochastic:
if diff_eq.is_functional_noise:
dt = profile.get_dt()
var_name = diff_eq.var_name
func_name = diff_eq.func_name
var = sympy.Symbol(var_name, real=True)

# k1 part #
# ------- #

# df
f_k1_expressions = diff_eq.get_f_expressions(substitute_vars=None)
code_lines = [str(expr) for expr in f_k1_expressions[:-1]]
code_lines.append(f'_df{var_name}_dt_k1 = {f_k1_expressions[-1].code}')

# dg
dW_sb = sympy.Symbol(f'_{var_name}_dW')
noise = f'_normal_like_({var_name})'
code_lines.append(f'{dW_sb.name} = sqrt({dt}) * {noise}')
g_k1_expressions = diff_eq.get_g_expressions()
code_lines.extend([str(expr) for expr in g_k1_expressions[:-1]])
code_lines.append(f'_dg{var_name}_dt_k1 = {g_k1_expressions[-1].code}')

# k1
code_lines.append(f'_{func_name}_k1 = {var_name} + _df{var_name}_dt_k1 * {dt} + '
f'_dg{var_name}_dt_k1 * {dW_sb.name}')

# k2 part #
# ------- #

# df
dfdt = sympy.Symbol(f'_df{var_name}_dt')
f_k2_expressions = diff_eq.replace_f_expressions('k2', y_sub=f'_{func_name}_k1')
if len(f_k2_expressions):
code_lines.extend([str(expr) for expr in f_k2_expressions[:-1]])
code_lines.append(f'_df{var_name}_dt_k2 = {f_k2_expressions[-1].code}')
code_lines.append(f'{dfdt.name} = (_df{var_name}_dt_k1 + _df{var_name}_dt_k2) / 2')
else:
code_lines.append(f'{dfdt.name} = _df{var_name}_dt_k1')

# dg
dgdt = sympy.Symbol(f'_dg{var_name}_dt')
g_k2_expressions = diff_eq.replace_f_expressions('k2', y_sub=f'_{func_name}_k1')
if len(g_k2_expressions):
code_lines.extend([str(expr) for expr in g_k2_expressions[:-1]])
code_lines.append(f'_dg{var_name}_dt_k2 = {g_k2_expressions[-1].code}')
code_lines.append(f'{dgdt.name} = (_dg{var_name}_dt_k1 + _dg{var_name}_dt_k2) / 2')
else:
code_lines.append(f'{dgdt.name} = _dg{var_name}_dt_k1')

# update expression
update = var + dfdt * dt + dgdt * dW_sb
code_lines.append(f'{var_name} = {sympy2str(update)}')

# multiple returns
return_expr = ', '.join([var_name] + diff_eq.returns)
code_lines.append(f'_{func_name}_res = {return_expr}')

# final
code = '\n'.join(code_lines)
subs_dict = {arg: f'_{diff_eq.func_name}_{arg}' for arg in
diff_eq.func_args + diff_eq.expr_names}
code = tools.word_replace(code, subs_dict)
return code
else:
return Euler.get_integral_step(diff_eq)
else:
return RK2.get_integral_step(diff_eq, 1.0)


class MidPoint(Integrator):
"""Explicit midpoint Euler method. Also named as ``modified_Euler``.

Parameters
----------
diff_eq : DiffEquation
The differential equation.

Returns
-------
func : callable
The one-step numerical integrator function.

See Also
--------
RK2, Heun
"""

def __init__(self, diff_eq):
super(MidPoint, self).__init__(diff_eq)
self._update_code = self.get_integral_step(diff_eq)
self._compile()

@staticmethod
def get_integral_step(diff_eq, *args):
if diff_eq.is_stochastic:
raise NotImplementedError
else:
return RK2.get_integral_step(diff_eq, 0.5)


class RK3(Integrator):
"""Kutta's third-order method (commonly known as RK3).
Also named as ``RK3`` [6]_ [7]_ [8]_ .

.. math::

k_1 &= f(y_n, t_n) \\\\
k_2 &= f(y_n + \\frac{\\Delta t}{2}k_1, t_n+\\frac{\\Delta t}{2}) \\\\
k_3 &= f(y_n -\\Delta t k_1 + 2\\Delta t k_2, t_n + \\Delta t) \\\\
y_{n+1} &= y_{n} + \\frac{\\Delta t}{6}(k_1 + 4k_2+k_3)

Parameters
----------
diff_eq : DiffEquation
The differential equation.

Returns
-------
func : callable
The one-step numerical integrator function.

References
----------
.. [6] http://mathworld.wolfram.com/Runge-KuttaMethod.html
.. [7] https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods
.. [8] https://zh.wikipedia.org/wiki/龙格-库塔法

"""

def __init__(self, diff_eq):
super(RK3, self).__init__(diff_eq)
self._update_code = self.get_integral_step(diff_eq)
self._compile()

@staticmethod
def get_integral_step(diff_eq, *args):
dt = profile.get_dt()
t_name = diff_eq.t_name
var_name = diff_eq.var_name
func_name = diff_eq.func_name
var = sympy.Symbol(var_name, real=True)

# get code lines of k1 df part
k1_expressions = diff_eq.get_f_expressions(substitute_vars=None)
code_lines = [str(expr) for expr in k1_expressions[:-1]]
code_lines.append(f'_df{var_name}_dt_k1 = {k1_expressions[-1].code}')

# k1 -> k2 increment
y_1_to_2 = f'_{func_name}_{var_name}_k1_to_k2'
t_1_to_2 = f'_{func_name}_t_k1_to_k2'
code_lines.append(f'{y_1_to_2} = {var_name} + {dt / 2} * _df{var_name}_dt_k1')
code_lines.append(f'{t_1_to_2} = {t_name} + {dt / 2}')

# get code lines of k2 df part
k2_expressions = diff_eq.replace_f_expressions('k2', y_sub=y_1_to_2, t_sub=t_1_to_2)

dfdt = sympy.Symbol(f'_df{var_name}_dt')
if len(k2_expressions):
code_lines.extend([str(expr) for expr in k2_expressions[:-1]])
code_lines.append(f'_df{var_name}_dt_k2 = {k2_expressions[-1].code}')

# get code lines of k3 df part
y_1_to_3 = f'_{func_name}_{var_name}_k1_to_k3'
t_1_to_3 = f'_{func_name}_t_k1_to_k3'
code_lines.append(f'{y_1_to_3} = {var_name} - {dt} * _df{var_name}_dt_k1 + {2 * dt} * _df{var_name}_dt_k2')
code_lines.append(f'{t_1_to_3} = {t_name} + {dt}')
k3_expressions = diff_eq.replace_f_expressions('k3', y_sub=y_1_to_3, t_sub=t_1_to_3)
code_lines.extend([str(expr) for expr in k3_expressions[:-1]])
code_lines.append(f'_df{var_name}_dt_k3 = {k3_expressions[-1].code}')

# final df part
code_lines.append(f'{dfdt.name} = (_df{var_name}_dt_k1 + '
f'4 * _df{var_name}_dt_k2 + _df{var_name}_dt_k3) / 6')
else:
# final df part
code_lines.append(f'{dfdt.name} = _df{var_name}_dt_k1')

# get code lines of dg part
dgdt = 0
if diff_eq.is_stochastic:
if not np.all(diff_eq.g_value == 0.):
raise NotImplementedError('RK3 currently doesn\'t support SDE.')

# update expression
update = var + dfdt * dt + sympy.sqrt(dt) * dgdt
code_lines.append(f'{var_name} = {sympy2str(update)}')

# multiple returns
return_expr = ', '.join([var_name] + diff_eq.returns)
code_lines.append(f'_{func_name}_res = {return_expr}')

# final
code = '\n'.join(code_lines)
subs_dict = {arg: f'_{diff_eq.func_name}_{arg}' for arg in
diff_eq.func_args + diff_eq.expr_names}
code = tools.word_replace(code, subs_dict)
return code


class RK4(Integrator):
"""Fourth-order Runge-Kutta (RK4) [9]_ [10]_ [11]_ .

.. math::

k_1 &= f(y_n, t_n) \\\\
k_2 &= f(y_n + \\frac{\\Delta t}{2}k_1, t_n + \\frac{\\Delta t}{2}) \\\\
k_3 &= f(y_n + \\frac{\\Delta t}{2}k_2, t_n + \\frac{\\Delta t}{2}) \\\\
k_4 &= f(y_n + \\Delta t k_3, t_n + \\Delta t) \\\\
y_{n+1} &= y_n + \\frac{\\Delta t}{6}(k_1 + 2k_2 + 2k_3 + k_4)

Parameters
----------
diff_eq : DiffEquation
The differential equation.

Returns
-------
func : callable
The one-step numerical integrator function.

References
----------
.. [9] http://mathworld.wolfram.com/Runge-KuttaMethod.html
.. [10] https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods
.. [11] https://zh.wikipedia.org/wiki/龙格-库塔法

"""

def __init__(self, diff_eq):
super(RK4, self).__init__(diff_eq)
self._update_code = self.get_integral_step(diff_eq)
self._compile()

@staticmethod
def get_integral_step(diff_eq, *args):
dt = profile.get_dt()
t_name = diff_eq.t_name
var_name = diff_eq.var_name
func_name = diff_eq.func_name
var = sympy.Symbol(var_name, real=True)

# get code lines of k1 df part
k1_expressions = diff_eq.get_f_expressions(substitute_vars=None)
code_lines = [str(expr) for expr in k1_expressions[:-1]]
code_lines.append(f'_df{var_name}_dt_k1 = {k1_expressions[-1].code}')

# k1 -> k2 increment
y_1_to_2 = f'_{func_name}_{var_name}_k1_to_k2'
t_1_to_2 = f'_{func_name}_t_k1_to_k2'
code_lines.append(f'{y_1_to_2} = {var_name} + {dt / 2} * _df{var_name}_dt_k1')
code_lines.append(f'{t_1_to_2} = {t_name} + {dt / 2}')

# get code lines of k2 df part
k2_expressions = diff_eq.replace_f_expressions('k2', y_sub=y_1_to_2, t_sub=t_1_to_2)

dfdt = sympy.Symbol(f'_df{var_name}_dt')
if len(k2_expressions):
code_lines.extend([str(expr) for expr in k2_expressions[:-1]])
code_lines.append(f'_df{var_name}_dt_k2 = {k2_expressions[-1].code}')

# get code lines of k3 df part
y_2_to_3 = f'_{func_name}_{var_name}_k2_to_k3'
t_2_to_3 = f'_{func_name}_t_k2_to_k3'
code_lines.append(f'{y_2_to_3} = {var_name} + {dt / 2} * _df{var_name}_dt_k2')
code_lines.append(f'{t_2_to_3} = {t_name} + {dt / 2}')
k3_expressions = diff_eq.replace_f_expressions('k3', y_sub=y_2_to_3, t_sub=t_2_to_3)
code_lines.extend([str(expr) for expr in k3_expressions[:-1]])
code_lines.append(f'_df{var_name}_dt_k3 = {k3_expressions[-1].code}')

# get code lines of k4 df part
y_3_to_4 = f'_{func_name}_{var_name}_k3_to_k4'
t_3_to_4 = f'_{func_name}_t_k3_to_k4'
code_lines.append(f'{y_3_to_4} = {var_name} + {dt} * _df{var_name}_dt_k3')
code_lines.append(f'{t_3_to_4} = {t_name} + {dt}')
k4_expressions = diff_eq.replace_f_expressions('k4', y_sub=y_3_to_4, t_sub=t_3_to_4)
code_lines.extend([str(expr) for expr in k4_expressions[:-1]])
code_lines.append(f'_df{var_name}_dt_k4 = {k4_expressions[-1].code}')

# final df part
code_lines.append(f'{dfdt.name} = (_df{var_name}_dt_k1 + 2 * _df{var_name}_dt_k2 + '
f'2 * _df{var_name}_dt_k3 + _df{var_name}_dt_k4) / 6')
else:
# final df part
code_lines.append(f'{dfdt.name} = _df{var_name}_dt_k1')

# get code lines of dg part
dgdt = 0
if diff_eq.is_stochastic:
if not np.all(diff_eq.g_value == 0.):
raise NotImplementedError('RK4 currently doesn\'t support SDE.')

# update expression
update = var + dfdt * dt + sympy.sqrt(dt) * dgdt
code_lines.append(f'{var_name} = {sympy2str(update)}')

# multiple returns
return_expr = ', '.join([var_name] + diff_eq.returns)
code_lines.append(f'_{func_name}_res = {return_expr}')

# final
code = '\n'.join(code_lines)
subs_dict = {arg: f'_{diff_eq.func_name}_{arg}' for arg in
diff_eq.func_args + diff_eq.expr_names}
code = tools.word_replace(code, subs_dict)
return code
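

# --- Illustrative sketch (not part of this module) ---
# The classical four-stage update implemented by the generated code,
# written out by hand for a single ODE variable:
#
# def rk4_step(y, t, f, dt):
#     k1 = f(y, t)
#     k2 = f(y + dt / 2 * k1, t + dt / 2)
#     k3 = f(y + dt / 2 * k2, t + dt / 2)
#     k4 = f(y + dt * k3, t + dt)
#     return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)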


class RK4Alternative(Integrator):
"""An alternative of fourth-order Runge-Kutta method.
Also named as ``RK4_alternative`` ("3/8" rule).

It is a less often used fourth-order
explicit RK method, and was also proposed by Kutta [12]_:

.. math::

k_1 &= f(y_n, t_n) \\\\
k_2 &= f(y_n + \\frac{\\Delta t}{3}k_1, t_n + \\frac{\\Delta t}{3}) \\\\
k_3 &= f(y_n - \\frac{\\Delta t}{3}k_1 + \\Delta t k_2, t_n + \\frac{2 \\Delta t}{3}) \\\\
k_4 &= f(y_n + \\Delta t k_1 - \\Delta t k_2 + \\Delta t k_3, t_n + \\Delta t) \\\\
y_{n+1} &= y_n + \\frac{\\Delta t}{8}(k_1 + 3k_2 + 3k_3 + k_4)

Parameters
----------
diff_eq : DiffEquation
The differential equation.

Returns
-------
func : callable
The one-step numerical integrator function.

References
----------

.. [12] https://en.wikipedia.org/wiki/List_of_Runge%E2%80%93Kutta_methods
"""

def __init__(self, diff_eq):
super(RK4Alternative, self).__init__(diff_eq)
self._update_code = self.get_integral_step(diff_eq)
self._compile()

@staticmethod
def get_integral_step(diff_eq, *args):
dt = profile.get_dt()
t_name = diff_eq.t_name
var_name = diff_eq.var_name
func_name = diff_eq.func_name
var = sympy.Symbol(var_name, real=True)

# get code lines of k1 df part
k1_expressions = diff_eq.get_f_expressions(substitute_vars=None)
code_lines = [str(expr) for expr in k1_expressions[:-1]]
code_lines.append(f'_df{var_name}_dt_k1 = {k1_expressions[-1].code}')

# k1 -> k2 increment
y_1_to_2 = f'_{func_name}_{var_name}_k1_to_k2'
t_1_to_2 = f'_{func_name}_t_k1_to_k2'
code_lines.append(f'{y_1_to_2} = {var_name} + {dt / 3} * _df{var_name}_dt_k1')
code_lines.append(f'{t_1_to_2} = {t_name} + {dt / 3}')

# get code lines of k2 df part
k2_expressions = diff_eq.replace_f_expressions('k2', y_sub=y_1_to_2, t_sub=t_1_to_2)

dfdt = sympy.Symbol(f'_df{var_name}_dt')
if len(k2_expressions):
code_lines.extend([str(expr) for expr in k2_expressions[:-1]])
code_lines.append(f'_df{var_name}_dt_k2 = {k2_expressions[-1].code}')

# get code lines of k3 df part
y_1_to_3 = f'_{func_name}_{var_name}_k1_to_k3'
t_1_to_3 = f'_{func_name}_t_k1_to_k3'
code_lines.append(f'{y_1_to_3} = {var_name} - {dt / 3} * _df{var_name}_dt_k1 + {dt} * _df{var_name}_dt_k2')
code_lines.append(f'{t_1_to_3} = {t_name} + {dt * 2 / 3}')
k3_expressions = diff_eq.replace_f_expressions('k3', y_sub=y_1_to_3, t_sub=t_1_to_3)
code_lines.extend([str(expr) for expr in k3_expressions[:-1]])
code_lines.append(f'_df{var_name}_dt_k3 = {k3_expressions[-1].code}')

# get code lines of k4 df part
y_1_to_4 = f'_{func_name}_{var_name}_k1_to_k4'
t_1_to_4 = f'_{func_name}_t_k1_to_k4'
code_lines.append(f'{y_1_to_4} = {var_name} + {dt} * _df{var_name}_dt_k1 - {dt} * _df{var_name}_dt_k2'
f' + {dt} * _df{var_name}_dt_k3')
code_lines.append(f'{t_1_to_4} = {t_name} + {dt}')
k4_expressions = diff_eq.replace_f_expressions('k4', y_sub=y_1_to_4, t_sub=t_1_to_4)
code_lines.extend([str(expr) for expr in k4_expressions[:-1]])
code_lines.append(f'_df{var_name}_dt_k4 = {k4_expressions[-1].code}')

# final df part
code_lines.append(f'{dfdt.name} = (_df{var_name}_dt_k1 + 3 * _df{var_name}_dt_k2 + '
f'3 * _df{var_name}_dt_k3 + _df{var_name}_dt_k4) / 8')
else:
# final df part
code_lines.append(f'{dfdt.name} = _df{var_name}_dt_k1')

# get code lines of dg part
dgdt = 0
if diff_eq.is_stochastic:
if not np.all(diff_eq.g_value == 0.):
raise NotImplementedError('RK4Alternative currently doesn\'t support SDE.')

# update expression
update = var + dfdt * dt + sympy.sqrt(dt) * dgdt
code_lines.append(f'{var_name} = {sympy2str(update)}')

# multiple returns
return_expr = ', '.join([var_name] + diff_eq.returns)
code_lines.append(f'_{func_name}_res = {return_expr}')

# final
code = '\n'.join(code_lines)
subs_dict = {arg: f'_{diff_eq.func_name}_{arg}' for arg in
diff_eq.func_args + diff_eq.expr_names}
code = tools.word_replace(code, subs_dict)
return code


class ExponentialEuler(Integrator):
"""First order, explicit exponential Euler method.

For an ODE equation of the form

.. math::

y^{\\prime}=f(y), \\quad y(0)=y_{0}

its scheme is given by

.. math::

y_{n+1}= y_{n}+h \\varphi(hA) f (y_{n})

where :math:`A=f^{\\prime}(y_{n})` and :math:`\\varphi(z)=\\frac{e^{z}-1}{z}`.

For linear ODE system: :math:`y^{\\prime} = Ay + B`,
the above equation is equal to

.. math::

y_{n+1}= y_{n}e^{hA}-B/A(1-e^{hA})

For a SDE equation of the form

.. math::

d y=(Ay+ F(y))dt + g(y)dW(t) = f(y)dt + g(y)dW(t), \\quad y(0)=y_{0}

its scheme is given by [16]_

.. math::

y_{n+1} & =e^{\\Delta t A}(y_{n}+ g(y_n)\\Delta W_{n})+\\varphi(\\Delta t A) F(y_{n}) \\Delta t \\\\
&= y_n + \\Delta t \\varphi(\\Delta t A) f(y) + e^{\\Delta t A}g(y_n)\\Delta W_{n}

where :math:`\\varphi(z)=\\frac{e^{z}-1}{z}`.

Parameters
----------
diff_eq : DiffEquation
The differential equation.

Returns
-------
func : callable
The one-step numerical integrator function.

References
----------
.. [16] Erdoğan, Utku, and Gabriel J. Lord. "A new class of exponential integrators for stochastic
differential equations with multiplicative noise." arXiv preprint arXiv:1608.07096 (2016).
"""

def __init__(self, diff_eq):
super(ExponentialEuler, self).__init__(diff_eq)
self._update_code = self.get_integral_step(diff_eq)
self._compile()

@staticmethod
def get_integral_step(diff_eq, *args):
dt = profile.get_dt()
f_expressions = diff_eq.get_f_expressions(substitute_vars=diff_eq.var_name)

# code lines
code_lines = [str(expr) for expr in f_expressions[:-1]]

# get the linear system using sympy
f_res = f_expressions[-1]
df_expr = str2sympy(f_res.code).expr.expand()
s_df = sympy.Symbol(f"{f_res.var_name}")
code_lines.append(f'{s_df.name} = {sympy2str(df_expr)}')
var = sympy.Symbol(diff_eq.var_name, real=True)

# get df part
s_linear = sympy.Symbol(f'_{diff_eq.var_name}_linear')
s_linear_exp = sympy.Symbol(f'_{diff_eq.var_name}_linear_exp')
s_df_part = sympy.Symbol(f'_{diff_eq.var_name}_df_part')
if df_expr.has(var):
# linear
linear = sympy.collect(df_expr, var, evaluate=False)[var]
code_lines.append(f'{s_linear.name} = {sympy2str(linear)}')
# linear exponential
linear_exp = sympy.exp(linear * dt)
code_lines.append(f'{s_linear_exp.name} = {sympy2str(linear_exp)}')
# df part
df_part = (s_linear_exp - 1) / s_linear * s_df
code_lines.append(f'{s_df_part.name} = {sympy2str(df_part)}')

else:
# no linear term: the exponential factor degenerates to 1
code_lines.append(f'{s_linear_exp.name} = 1.0')
# df part
code_lines.append(f'{s_df_part.name} = {sympy2str(dt * s_df)}')

# get dg part
if diff_eq.is_stochastic:
# dW: a Wiener increment, scaled by sqrt(dt) as in the other SDE schemes
noise = f'_normal_like_({diff_eq.var_name})'
code_lines.append(f'_{diff_eq.var_name}_dW = sqrt({dt}) * {noise}')
# expressions of the stochastic part
g_expressions = diff_eq.get_g_expressions()
code_lines.extend([str(expr) for expr in g_expressions[:-1]])
g_expr = g_expressions[-1].code
# get the dg_part
s_dg_part = sympy.Symbol(f'_{diff_eq.var_name}_dg_part')
code_lines.append(f'_{diff_eq.var_name}_dg_part = {g_expr} * _{diff_eq.var_name}_dW')
else:
s_dg_part = 0

# update expression
update = var + s_df_part + s_dg_part * s_linear_exp

# The actual update step
code_lines.append(f'{diff_eq.var_name} = {sympy2str(update)}')
return_expr = ', '.join([diff_eq.var_name] + diff_eq.returns)
code_lines.append(f'_{diff_eq.func_name}_res = {return_expr}')

# final
code = '\n'.join(code_lines)
subs_dict = {arg: f'_{diff_eq.func_name}_{arg}' for arg in
diff_eq.func_args + diff_eq.expr_names}
code = tools.word_replace(code, subs_dict)
return code
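

# --- Illustrative sketch (not part of this module) ---
# For a leaky equation dV/dt = (-V + I) / tau, the linear coefficient is
# A = -1/tau, and the generated update reduces to the closed form
# V <- V + (exp(A*dt) - 1) / A * dV/dt. Hand-written, with hypothetical
# parameter names:
#
# import numpy as np
#
# def exp_euler_step(V, I, tau, dt):
#     dVdt = (-V + I) / tau
#     A = -1. / tau  # coefficient of the linear (in V) part of dV/dt
#     return V + (np.exp(A * dt) - 1.) / A * dVdt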


class MilsteinIto(Integrator):
"""Itô stochastic integral. The derivative-free Milstein method is
an order 1.0 strong Taylor schema.

The following implementation approximates this derivative thanks to a
Runge-Kutta approach [13]_.

In Itô scheme, it is expressed as

.. math::

Y_{n+1} = Y_n + f_n h + g_n \\Delta W_n + {1 \\over 2\\sqrt{h}}
[g(\\overline{Y_n}) - g_n] [(\\Delta W_n)^2-h]

where :math:`\\overline{Y_n} = Y_n + f_n h + g_n \\sqrt{h}`.

Parameters
----------
diff_eq : DiffEquation
The differential equation.

Returns
-------
func : callable
The one-step numerical integrator function.

References
----------
.. [13] P. E. Kloeden, E. Platen, and H. Schurz, Numerical solution of SDE
through computer experiments, Springer, 1994.

"""

def __init__(self, diff_eq):
super(MilsteinIto, self).__init__(diff_eq)
self._update_code = self.get_integral_step(diff_eq)
self._compile()

@staticmethod
def get_integral_step(diff_eq, *args):
if diff_eq.is_stochastic:
if diff_eq.is_functional_noise:
# check whether the diffusion term g depends on the integrated variable
g_dependent_on_var = diff_eq.replace_g_expressions('test', y_sub='test')
if len(g_dependent_on_var) == 0:
return Euler.get_integral_step(diff_eq)

dt = profile.get_dt()
var_name = diff_eq.var_name
func_name = diff_eq.func_name

# k1 part #
# ------- #

# df
f_k1_expressions = diff_eq.get_f_expressions(substitute_vars=None)
code_lines = [str(expr) for expr in f_k1_expressions] # _df{var_name}_dt

# dg
dW_sb = sympy.Symbol(f'_{var_name}_dW')
noise = f'_normal_like_({var_name})'
code_lines.append(f'{dW_sb.name} = sqrt({dt}) * {noise}')
g_k1_expressions = diff_eq.get_g_expressions()
code_lines.extend([str(expr) for expr in g_k1_expressions]) # _dg{var_name}_dt

# high order part #
# --------------- #
k1_expr = f'_{func_name}_k1 = {var_name} + _df{var_name}_dt * {dt} + ' \
f'_dg{var_name}_dt * sqrt({dt})'
high_order = sympy.Symbol(f'_dg{var_name}_high_order')
g_k2_expressions = diff_eq.replace_g_expressions('k2', y_sub=f'_{func_name}_k1')

# dg high order
if len(g_k2_expressions):
code_lines.append(k1_expr)
code_lines.extend([str(expr) for expr in g_k2_expressions[:-1]])
code_lines.append(f'_dg{var_name}_dt_k2 = {g_k2_expressions[-1].code}')
code_lines.append(f'{high_order.name} = 0.5 / sqrt({dt}) * '
f'(_dg{var_name}_dt_k2 - _dg{var_name}_dt) *'
f'({dW_sb.name} * {dW_sb.name} - {dt})')
code_lines.append(f'{var_name} = {var_name} + _df{var_name}_dt * {dt} + '
f'_dg{var_name}_dt * {dW_sb.name} + {high_order.name}')
else:
code_lines.append(f'{var_name} = {var_name} + _df{var_name}_dt * {dt} + '
f'_dg{var_name}_dt * {dW_sb.name}')

# multiple returns
return_expr = ', '.join([var_name] + diff_eq.returns)
code_lines.append(f'_{func_name}_res = {return_expr}')

# final
code = '\n'.join(code_lines)
subs_dict = {arg: f'_{diff_eq.func_name}_{arg}' for arg in
diff_eq.func_args + diff_eq.expr_names}
code = tools.word_replace(code, subs_dict)
return code

return Euler.get_integral_step(diff_eq)
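

# --- Illustrative sketch (not part of this module) ---
# The derivative-free Milstein (Ito) step generated above, written out
# by hand; ``f`` and ``g`` are hypothetical drift and diffusion functions:
#
# import numpy as np
#
# def milstein_ito_step(y, f, g, dt):
#     dW = np.sqrt(dt) * np.random.normal(0., 1., np.shape(y))
#     y_bar = y + f(y) * dt + g(y) * np.sqrt(dt)
#     high_order = 0.5 / np.sqrt(dt) * (g(y_bar) - g(y)) * (dW * dW - dt)
#     return y + f(y) * dt + g(y) * dW + high_order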


class MilsteinStra(Integrator):
"""Heun two-stage stochastic numerical method for Stratonovich integral.

Use the Stratonovich Heun algorithm to integrate Stratonovich equation,
according to paper [14]_ [15]_.

.. math::
Y_{n+1} &= Y_n + f_n h + {1 \\over 2}[g_n + g(\\overline{Y}_n)] \\Delta W_n

\\overline{Y}_n &= Y_n + g_n \\Delta W_n


Or, it is written as

.. math::

Y_1 &= y_n + f(y_n)h + g_n \\Delta W_n

y_{n+1} &= y_n + {1 \over 2}[f(y_n) + f(Y_1)]h + {1 \\over 2} [g(y_n) + g(Y_1)] \\Delta W_n


Parameters
----------
diff_eq : DiffEquation
The differential equation.

Returns
-------
func : callable
The one-step numerical integrator function.

References
----------

.. [14] H. Gilsing and T. Shardlow, SDELab: A package for solving stochastic differential
equations in MATLAB, Journal of Computational and Applied Mathematics 205 (2007),
no. 2, 1002-1018.
.. [15] P. E. Kloeden, E. Platen, and H. Schurz, Numerical solution of SDE through computer
experiments, Springer, 1994.

See Also
--------
MilsteinIto

"""

def __init__(self, diff_eq):
super(MilsteinStra, self).__init__(diff_eq)
self._update_code = self.get_integral_step(diff_eq)
self._compile()

@staticmethod
def get_integral_step(diff_eq, *args):
if diff_eq.is_stochastic:
if diff_eq.is_functional_noise:
# check whether the diffusion term g depends on the integrated variable
g_dependent_on_var = diff_eq.replace_g_expressions('test', y_sub='test')
if len(g_dependent_on_var) == 0:
return Euler.get_integral_step(diff_eq)

dt = profile.get_dt()
var_name = diff_eq.var_name
func_name = diff_eq.func_name

# k1 part #
# ------- #

# df
f_k1_expressions = diff_eq.get_f_expressions(substitute_vars=None)
code_lines = [str(expr) for expr in f_k1_expressions] # _df{var_name}_dt

# dg
dW_sb = sympy.Symbol(f'_{var_name}_dW')
noise = f'_normal_like_({var_name})'
code_lines.append(f'{dW_sb.name} = sqrt({dt}) * {noise}')
g_k1_expressions = diff_eq.get_g_expressions()
code_lines.extend([str(expr) for expr in g_k1_expressions]) # _dg{var_name}_dt

# high order part #
# --------------- #

k1_expr = f'_{func_name}_k1 = {var_name} + _df{var_name}_dt * {dt} + ' \
f'_dg{var_name}_dt * sqrt({dt})'
high_order = sympy.Symbol(f'_dg{var_name}_high_order')
g_k2_expressions = diff_eq.replace_g_expressions('k2', y_sub=f'_{func_name}_k1')
if len(g_k2_expressions):
code_lines.append(k1_expr)
code_lines.extend([str(expr) for expr in g_k2_expressions[:-1]])
code_lines.append(f'_dg{var_name}_dt_k2 = {g_k2_expressions[-1].code}')
code_lines.append(f'{high_order.name} = 0.5 / sqrt({dt}) * '
f'(_dg{var_name}_dt_k2 - _dg{var_name}_dt) *'
f'{dW_sb.name} * {dW_sb.name}')
code_lines.append(f'{var_name} = {var_name} + _df{var_name}_dt * {dt} + '
f'_dg{var_name}_dt * {dW_sb.name} + {high_order.name}')
else:
code_lines.append(f'{var_name} = {var_name} + _df{var_name}_dt * {dt} + '
f'_dg{var_name}_dt * {dW_sb.name}')

# multiple returns
return_expr = ', '.join([var_name] + diff_eq.returns)
code_lines.append(f'_{func_name}_res = {return_expr}')

# final
code = '\n'.join(code_lines)
subs_dict = {arg: f'_{diff_eq.func_name}_{arg}' for arg in
diff_eq.func_args + diff_eq.expr_names}
code = tools.word_replace(code, subs_dict)
return code

return Euler.get_integral_step(diff_eq)

+ 10 - 0 brainpy/integrators/__init__.py

@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-

from . import dde
from . import fde
from . import ode
from . import sde
from .ast_analysis import *
from .constants import *
from .delay_vars import *
from .integrate_wrapper import *

+ 220 - 0 brainpy/integrators/ast_analysis.py

@@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-

import ast
import inspect
from collections import OrderedDict

from brainpy import errors
from brainpy import tools

__all__ = [
'DiffEqReader',
'separate_variables',
]


class DiffEqReader(ast.NodeVisitor):
"""Read the code lines which defines the logic of a differential equation system.

Currently, DiffEqReader cannot handle the for loop, and if-else condition.
Also, it do not assign values by a functional call. Like this:

.. code-block:: python

func(a, b, c)

Instead, you should code like:

.. code-block:: python

c = func(a, b)

Therefore, this class only has minimum power to analyze differential
equations. For example, this class may help to automatically find out
the linear part of a differential equation, thus forming the
Exponential Euler numerical methods.
"""

def __init__(self):
self.code_lines = [] # list of str
self.variables = [] # list of list
self.returns = [] # list of str
self.rights = [] # list of str

@staticmethod
def visit_container(nodes):
variables = []
for var in nodes:
if isinstance(var, (ast.List, ast.Tuple)):
variables.extend(DiffEqReader.visit_container(var.elts))
elif isinstance(var, ast.Name):
variables.append(var.id)
else:
raise ValueError(f'Unknown target type: {var}')
return variables

def visit_Assign(self, node):
variables = []
for target in node.targets:
if isinstance(target, (ast.List, ast.Tuple)):
variables.extend(self.visit_container(target.elts))
elif isinstance(target, ast.Name):
variables.append(target.id)
else:
raise ValueError(f'Unknown target type: {target}')
self.variables.append(variables)
self.code_lines.append(tools.ast2code(ast.fix_missing_locations(node)))
self.rights.append(tools.ast2code(ast.fix_missing_locations(node.value)))
return node

def visit_AugAssign(self, node):
var = node.target.id
self.variables.append(var)
expr = tools.ast2code(ast.fix_missing_locations(node))
self.code_lines.append(expr)
self.rights.append(tools.ast2code(ast.fix_missing_locations(node.value)))
return node

def visit_Return(self, node):
if isinstance(node.value, ast.Name):
self.returns.append(node.value.id)
elif isinstance(node.value, (ast.Tuple, ast.List)):
for var in node.value.elts:
if not isinstance(var, ast.Name):
raise errors.DiffEqError(f'Unknown return type: {node}')
self.returns.append(var.id)
else:
raise errors.DiffEqError(f'Unknown return type: {node}')
return node

def visit_AnnAssign(self, node):
raise errors.DiffEqError(f'Currently, {self.__class__.__name__} does not support '
f'assignments with type annotations.')

def visit_If(self, node):
raise errors.DiffEqError(f'Currently, {self.__class__.__name__} does not support '
f'"if-else" conditions in differential equations.')

def visit_IfExp(self, node):
raise errors.DiffEqError(f'Currently, {self.__class__.__name__} does not support '
f'"if-else" expressions in differential equations.')

def visit_For(self, node):
raise errors.DiffEqError(f'Currently, {self.__class__.__name__} does not support '
f'"for" loops in differential equations.')

def visit_While(self, node):
raise errors.DiffEqError(f'Currently, {self.__class__.__name__} does not support '
f'"while" loops in differential equations.')

def visit_Try(self, node):
raise errors.DiffEqError(f'Currently, {self.__class__.__name__} does not support '
f'"try" handlers in differential equations.')

def visit_With(self, node):
raise errors.DiffEqError(f'Currently, {self.__class__.__name__} does not support '
f'"with" blocks in differential equations.')

def visit_Raise(self, node):
raise errors.DiffEqError(f'Currently, {self.__class__.__name__} does not support '
f'"raise" statements in differential equations.')

def visit_Delete(self, node):
raise errors.DiffEqError(f'Currently, {self.__class__.__name__} does not support '
f'"del" operations in differential equations.')

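# --- Illustrative usage sketch (not part of this module) ---
# Reading a one-equation derivative function; the names are hypothetical:
#
# import ast
#
# code = ('def dv(v, t, tau):\n'
#         '    dvdt = -v / tau\n'
#         '    return dvdt\n')
# reader = DiffEqReader()
# reader.visit(ast.parse(code))
# print(reader.code_lines)  # ['dvdt = -v / tau'] (as produced by ast2code)
# print(reader.variables)   # [['dvdt']]
# print(reader.returns)     # ['dvdt']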

def separate_variables(func_or_code):
"""Separate the expressions in a differential equation for each variable.

For example, take the HH neuron model as an example:

>>> eq_code = '''
>>> def integral(m, h, t, Iext, V):
>>> alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
>>> beta = 4.0 * np.exp(-(V + 65) / 18)
>>> dmdt = alpha * (1 - m) - beta * m
>>>
>>> alpha = 0.07 * np.exp(-(V + 65) / 20.)
>>> beta = 1 / (1 + np.exp(-(V + 35) / 10))
>>> dhdt = alpha * (1 - h) - beta * h
>>> return dmdt, dhdt
>>> '''
>>> separate_variables(eq_code)['code_lines_for_returns']
{'dhdt': ['alpha = 0.07 * np.exp(-(V + 65) / 20.0)\n',
'beta = 1 / (1 + np.exp(-(V + 35) / 10))\n',
'dhdt = alpha * (1 - h) - beta * h\n'],
'dmdt': ['alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))\n',
'beta = 4.0 * np.exp(-(V + 65) / 18)\n',
'dmdt = alpha * (1 - m) - beta * m\n']}

Parameters
----------
func_or_code : callable, str
The callable function or the function code.

Returns
-------
analysis : dict
The analysis result: for each returned variable, its code lines,
assigned variables, and right-hand-side expressions.
"""
if callable(func_or_code):
func_or_code = tools.deindent(inspect.getsource(func_or_code))
assert isinstance(func_or_code, str)
analyser = DiffEqReader()
analyser.visit(ast.parse(func_or_code))

returns = analyser.returns
variables = analyser.variables
right_exprs = analyser.rights
code_lines = analyser.code_lines

return_requires = OrderedDict([(r, set(tools.get_identifiers(r))) for r in returns])
code_lines_for_returns = OrderedDict([(r, []) for r in returns])
variables_for_returns = OrderedDict([(r, []) for r in returns])
expressions_for_returns = OrderedDict([(r, []) for r in returns])

length = len(variables)
reverse_ids = list(reversed([i - length for i in range(length)]))
for r in code_lines_for_returns.keys():
for rid in reverse_ids:
dep = []
for v in variables[rid]:
if v in return_requires[r]:
dep.append(v)
if len(dep):
code_lines_for_returns[r].append(code_lines[rid])
variables_for_returns[r].append(variables[rid])
expr = right_exprs[rid]
expressions_for_returns[r].append(expr)
for d in dep:
return_requires[r].remove(d)
return_requires[r].update(tools.get_identifiers(expr))
for r in list(code_lines_for_returns.keys()):
code_lines_for_returns[r] = code_lines_for_returns[r][::-1]
variables_for_returns[r] = variables_for_returns[r][::-1]
expressions_for_returns[r] = expressions_for_returns[r][::-1]

analysis = tools.DictPlus(
code_lines_for_returns=code_lines_for_returns,
variables_for_returns=variables_for_returns,
expressions_for_returns=expressions_for_returns,
)
return analysis


# def dissect_diff_eq(func_or_code):
# if callable(func_or_code):
# func_or_code = tools.deindent(inspect.getsource(func_or_code))
# assert isinstance(func_or_code, str)
# analyser = DiffEqReader()
# analyser.visit(ast.parse(func_or_code))
# return separate_variables(returns=analyser.returns,
# variables=analyser.variables,
# right_exprs=analyser.rights,
# code_lines=analyser.code_lines)

+ 95 - 0 brainpy/integrators/constants.py

@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-


__all__ = [
'SUPPORTED_VAR_TYPE',
'SCALAR_VAR',
'POPU_VAR',
'SYSTEM_VAR',

'SUPPORTED_WIENER_TYPE',
'SCALAR_WIENER',
'VECTOR_WIENER',

'SUPPORTED_SDE_TYPE',
'ITO_SDE',
'STRA_SDE',

'NAME_PREFIX',
]

# Ito SDE
# ---
#
ITO_SDE = 'Ito'

# Stratonovich SDE
# ---
#
STRA_SDE = 'Stratonovich'

SUPPORTED_SDE_TYPE = [
ITO_SDE,
STRA_SDE
]

# Scalar Wiener process
# ----
#
SCALAR_WIENER = 'scalar_wiener'

# Vector Wiener process
# ----
#
VECTOR_WIENER = 'vector_wiener'

SUPPORTED_WIENER_TYPE = [
SCALAR_WIENER,
VECTOR_WIENER
]

# Denotes each variable is a scalar variable
# -------
# For example:
#
# def derivative(a, b, t):
# ...
# return da, db
#
# The "a" and "b" are scalars: a=1, b=2
#
SCALAR_VAR = 'scalar'

# Denotes each variable is a homogeneous population
# -------
# For example:
#
# def derivative(a, b, t):
# ...
# return da, db
#
# The "a" and "b" are vectors or matrix:
# a = np.array([1,2]), b = np.array([3,4])
# or,
# a = np.array([[1,2], [2,1]]), b=np.array([[3,4], [4,3]])
#
POPU_VAR = 'population'

# Denotes each variable is a system
# ------
# For example, the above defined differential equations can be defined as:
#
# def derivative(x, t):
# a, b = x
# ...
# dx = np.array([da, db])
# return dx
SYSTEM_VAR = 'system'

SUPPORTED_VAR_TYPE = [
SCALAR_VAR,
POPU_VAR,
SYSTEM_VAR,
]

NAME_PREFIX = '_brainpy_numint_of_'

+ 1 - 0 brainpy/integrators/dde/__init__.py

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

+ 73 - 0 brainpy/integrators/delay_vars.py

@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-


import abc
import math

from brainpy import backend

__all__ = [
'AbstractDelay',
'ConstantDelay',
'VaryingDelay',
'NeutralDelay',
]


class AbstractDelay(abc.ABC):
def __setitem__(self, time, value):
pass

def __getitem__(self, time):
pass


class ConstantDelay(AbstractDelay):
def __init__(self, v0, delay_len, before_t0=0., t0=0., dt=None):
# size
self.size = backend.shape(v0)

# delay_len
self.delay_len = delay_len
self.dt = backend.get_dt() if dt is None else dt
self.num_delay = int(math.ceil(delay_len / self.dt))

# other variables
self._delay_in = self.num_delay - 1
self._delay_out = 0
self.current_time = t0

# before_t0
self.before_t0 = before_t0

# delay data
self.data = backend.zeros((self.num_delay + 1,) + self.size)
if callable(before_t0):
for i in range(self.num_delay):
self.data[i] = before_t0(t0 + (i - self.num_delay) * self.dt)
else:
self.data[:-1] = before_t0
self.data[-1] = v0

def __setitem__(self, time, value): # push
self.data[self._delay_in] = value
self.current_time = time

def __getitem__(self, time): # pull
# NOTE: time-based indexing is not implemented yet; the requested
# ``time`` is ignored and the oldest buffered value is returned.
return self.data[self._delay_out]

def update(self):
self._delay_in = (self._delay_in + 1) % self.num_delay
self._delay_out = (self._delay_out + 1) % self.num_delay
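
# --- Illustrative usage sketch (not part of this module) ---
# A 2.0 ms constant delay on a scalar variable, pushed and pulled once
# per time step; the numbers are hypothetical:
#
# delay = ConstantDelay(v0=0., delay_len=2., before_t0=0., dt=0.1)
# delay[t] = v_now           # push the current value
# v_delayed = delay[t - 2.]  # pull the value from 2.0 ms ago
# delay.update()             # advance the ring-buffer pointers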


class VaryingDelay(AbstractDelay):
def __init__(self):
pass


class NeutralDelay(AbstractDelay):
def __init__(self):
pass

+ 1 - 0 brainpy/integrators/fde/__init__.py

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

+ 120 - 0 brainpy/integrators/integrate_wrapper.py

@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-

from . import ode
from . import sde

__all__ = [
'SUPPORTED_ODE_METHODS',
'SUPPORTED_SDE_METHODS',


'odeint',
'sdeint',
'ddeint',
'fdeint',

'set_default_odeint',
'get_default_odeint',
'set_default_sdeint',
'get_default_sdeint',
]

_DEFAULT_ODE_METHOD = 'euler'
_DEFAULT_SDE_METHOD = 'euler'
SUPPORTED_ODE_METHODS = [m for m in dir(ode) if not m.startswith('__') and callable(getattr(ode, m))]
SUPPORTED_SDE_METHODS = [m for m in dir(sde) if not m.startswith('__') and callable(getattr(sde, m))]


def _wrapper(f, method, module, **kwargs):
integrator = getattr(module, method)
return integrator(f, **kwargs)


def odeint(f=None, method=None, **kwargs):
if method is None:
method = _DEFAULT_ODE_METHOD
if method not in SUPPORTED_ODE_METHODS:
raise ValueError(f'Unknown ODE numerical method "{method}". Currently '
f'BrainPy only support: {SUPPORTED_ODE_METHODS}')

if f is None:
return lambda f: _wrapper(f, method=method, module=ode, **kwargs)
else:
return _wrapper(f, method=method, module=ode, **kwargs)


def sdeint(f=None, method=None, **kwargs):
if method is None:
method = _DEFAULT_SDE_METHOD
if method not in SUPPORTED_SDE_METHODS:
raise ValueError(f'Unknown SDE numerical method "{method}". Currently '
f'BrainPy only support: {SUPPORTED_SDE_METHODS}')

if f is None:
return lambda f: _wrapper(f, method=method, module=sde, **kwargs)
else:
return _wrapper(f, method=method, module=sde, **kwargs)


def ddeint():
raise NotImplementedError


def fdeint():
raise NotImplementedError


def set_default_odeint(method):
"""Set the default ODE numerical integrator method for differential equations.

Parameters
----------
method : str, callable
Numerical integrator method.
"""
if not isinstance(method, str):
raise ValueError(f'Only support string, not {type(method)}.')
if method not in SUPPORTED_ODE_METHODS:
raise ValueError(f'Unsupported ODE numerical method: {method}.')

global _DEFAULT_ODE_METHOD
_DEFAULT_ODE_METHOD = method


def get_default_odeint():
"""Get the default ODE numerical integrator method.

Returns
-------
method : str
The default numerical integrator method.
"""
return _DEFAULT_ODE_METHOD


def set_default_sdeint(method):
"""Set the default SDE numerical integrator method for differential equations.

Parameters
----------
method : str, callable
Numerical integrator method.
"""
if not isinstance(method, str):
raise ValueError(f'Only support string, not {type(method)}.')
if method not in SUPPORTED_SDE_METHODS:
raise ValueError(f'Unsupported SDE numerical method: {method}.')

global _DEFAULT_SDE_METHOD
_DEFAULT_SDE_METHOD = method


def get_default_sdeint():
"""Get the default ODE numerical integrator method.

Returns
-------
method : str
The default numerical integrator method.
"""
return _DEFAULT_SDE_METHOD
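

# --- Illustrative usage sketch (not part of this module) ---
# ``odeint`` works both as a plain call and as a decorator; the derivative
# function below is a hypothetical example:
#
# @odeint(method='rk4')
# def dv(v, t, tau):
#     return -v / tau
#
# # ``dv`` is now a one-step integrator (assuming the wrapped function
# # keeps the (v, t, ...) signature): v = dv(v, t, tau)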

+ 10 - 0 brainpy/integrators/ode/__init__.py

@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-

"""
Numerical methods for ordinary differential equations.
"""

from .rk_adaptive_methods import *
from .rk_methods import *
# from .other_methods import *


+ 18 - 0 brainpy/integrators/ode/exp_euler.py

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-

from brainpy import backend

__all__ = [
'exponential_euler',
]


def exponential_euler(f, return_linear_term=False):
# NOTE: the ``return_linear_term`` flag is currently unused.
dt = backend.get_dt()

def int_f(x, t, *args):
# ``f`` must return the derivative and the coefficient of its
# linear part: (df, linear_part)
df, linear_part = f(x, t, *args)
y = x + (backend.exp(linear_part * dt) - 1) / linear_part * df
return y

return int_f
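

# --- Illustrative usage sketch (not part of this module) ---
# For dV/dt = (-V + I) / tau, the linear coefficient is -1/tau, so the
# user-supplied ``f`` returns both the derivative and that coefficient;
# all names here are hypothetical:
#
# def derivative(V, t, I, tau):
#     dV = (-V + I) / tau
#     linear = -1. / tau
#     return dV, linear
#
# int_V = exponential_euler(derivative)
# # V = int_V(V, t, I, tau)   # one integration step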

+ 340 - 0 brainpy/integrators/ode/rk_adaptive_methods.py

@@ -0,0 +1,340 @@
# -*- coding: utf-8 -*-

from brainpy import backend
from brainpy.integrators import constants
from .wrapper import adaptive_rk_wrapper

__all__ = [
'rkf45',
'rkf12',
'rkdp',
'ck',
'bs',
'heun_euler'
]


def _base(A, B1, B2, C, f=None, tol=None, adaptive=None,
dt=None, show_code=None, var_type=None):
"""

Parameters
----------
A :
B1 :
B2 :
C :
f :
tol :
adaptive :
dt :
show_code :
var_type :

Returns
-------

"""
adaptive = False if (adaptive is None) else adaptive
dt = backend.get_dt() if (dt is None) else dt
tol = 0.1 if tol is None else tol
show_code = False if show_code is None else show_code
var_type = constants.POPU_VAR if var_type is None else var_type

if f is None:
return lambda f: adaptive_rk_wrapper(f, dt=dt, A=A, B1=B1, B2=B2, C=C, tol=tol,
adaptive=adaptive, show_code=show_code,
var_type=var_type)
else:
return adaptive_rk_wrapper(f, dt=dt, A=A, B1=B1, B2=B2, C=C, tol=tol,
adaptive=adaptive, show_code=show_code,
var_type=var_type)


def rkf45(f=None, tol=None, adaptive=None, dt=None, show_code=None, var_type=None):
"""The Runge–Kutta–Fehlberg method for ordinary differential equations.

The method presented in Fehlberg's 1969 paper has been dubbed the
RKF45 method, and is a method of order :math:`O(h^4)` with an error
estimator of order :math:`O(h^5)`. The novelty of Fehlberg's method is
that it is an embedded method from the Runge–Kutta family, meaning that
identical function evaluations are used in conjunction with each other
to create methods of varying order and similar error constants.

It has the characteristics of:

- method stage = 6
- method order = 5
- Butcher Tables:

.. math::

\\begin{array}{l|lllll}
0 & & & & & & \\\\
1 / 4 & 1 / 4 & & & & \\\\
3 / 8 & 3 / 32 & 9 / 32 & & \\\\
12 / 13 & 1932 / 2197 & -7200 / 2197 & 7296 / 2197 & \\\\
1 & 439 / 216 & -8 & 3680 / 513 & -845 / 4104 & & \\\\
1 / 2 & -8 / 27 & 2 & -3544 / 2565 & 1859 / 4104 & -11 / 40 & \\\\
\\hline & 16 / 135 & 0 & 6656 / 12825 & 28561 / 56430 & -9 / 50 & 2 / 55 \\\\
& 25 / 216 & 0 & 1408 / 2565 & 2197 / 4104 & -1 / 5 & 0
\\end{array}

References
----------

[1] https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method
[2] Erwin Fehlberg (1969). Low-order classical Runge-Kutta formulas with step
size control and their application to some heat transfer problems . NASA
Technical Report 315.
https://ntrs.nasa.gov/api/citations/19690021375/downloads/19690021375.pdf

"""

A = [(), (0.25,), (0.09375, 0.28125),
('1932/2197', '-7200/2197', '7296/2197'),
('439/216', -8, '3680/513', '-845/4104'),
('-8/27', 2, '-3544/2565', '1859/4104', -0.275)]
B1 = ['16/135', 0, '6656/12825', '28561/56430', -0.18, '2/55']
B2 = ['25/216', 0, '1408/2565', '2197/4104', -0.2, 0]
C = [0, 0.25, 0.375, '12/13', 1, '1/2']

return _base(A=A, B1=B1, B2=B2, C=C, f=f, dt=dt, tol=tol,
adaptive=adaptive, show_code=show_code, var_type=var_type)


def rkf12(f=None, tol=None, adaptive=None, dt=None, show_code=None, var_type=None):
"""The Fehlberg RK1(2) method for ordinary differential equations.

The Fehlberg method has two methods of orders 1 and 2.

It has the characteristics of:

- method stage = 2
- method order = 1
- Butcher Tables:

.. math::

\\begin{array}{l|ll}
0 & & \\\\
1 / 2 & 1 / 2 & \\\\
1 & 1 / 256 & 255 / 256 & \\\\
\\hline & 1 / 512 & 255 / 256 & 1 / 512 \\\\
& 1 / 256 & 255 / 256 & 0
\\end{array}

References
----------

.. [1] Fehlberg, E. (1969-07-01). "Low-order classical Runge-Kutta
formulas with stepsize control and their application to some heat
transfer problems"

"""

A = [(), (0.5,), ('1/256', '255/256')]
B1 = ['1/512', '255/256', '1/512']
B2 = ['1/256', '255/256', 0]
C = [0, 0.5, 1]

return _base(A=A, B1=B1, B2=B2, C=C, f=f, dt=dt, tol=tol,
adaptive=adaptive, show_code=show_code, var_type=var_type)


def rkdp(f=None, tol=None, adaptive=None, dt=None, show_code=None, var_type=None):
"""The Dormand–Prince method for ordinary differential equations.

The DOPRI method, is an explicit method for solving ordinary differential equations
(Dormand & Prince 1980). The Dormand–Prince method has seven stages, but it uses only
six function evaluations per step because it has the FSAL (First Same As Last) property:
the last stage is evaluated at the same point as the first stage of the next step.
Dormand and Prince chose the coefficients of their method to minimize the error of
the fifth-order solution. This is the main difference with the Fehlberg method, which
was constructed so that the fourth-order solution has a small error. For this reason,
the Dormand–Prince method is more suitable when the higher-order solution is used to
continue the integration, a practice known as local extrapolation
(Shampine 1986; Hairer, Nørsett & Wanner 2008, pp. 178–179).

It has the characteristics of:

- method stage = 7
- method order = 5
- Butcher Tables:

.. math::

\\begin{array}{l|llllll}
0 & \\\\
1 / 5 & 1 / 5 & & & \\\\
3 / 10 & 3 / 40 & 9 / 40 & & & \\\\
4 / 5 & 44 / 45 & -56 / 15 & 32 / 9 & & \\\\
8 / 9 & 19372 / 6561 & -25360 / 2187 & 64448 / 6561 & -212 / 729 & \\\\
1 & 9017 / 3168 & -355 / 33 & 46732 / 5247 & 49 / 176 & -5103 / 18656 & \\\\
1 & 35 / 384 & 0 & 500 / 1113 & 125 / 192 & -2187 / 6784 & 11 / 84 & \\\\
\\hline & 35 / 384 & 0 & 500 / 1113 & 125 / 192 & -2187 / 6784 & 11 / 84 & 0 \\\\
& 5179 / 57600 & 0 & 7571 / 16695 & 393 / 640 & -92097 / 339200 & 187 / 2100 & 1 / 40
\\end{array}

References
----------

[1] https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method
[2] Dormand, J. R.; Prince, P. J. (1980), "A family of embedded Runge-Kutta formulae",
Journal of Computational and Applied Mathematics, 6 (1): 19–26,
doi:10.1016/0771-050X(80)90013-3.
"""

A = [(), (0.2,), (0.075, 0.225),
('44/45', '-56/15', '32/9'),
('19372/6561', '-25360/2187', '64448/6561', '-212/729'),
('9017/3168', '-355/33', '46732/5247', '49/176', '-5103/18656'),
('35/384', 0, '500/1113', '125/192', '-2187/6784', '11/84')]
B1 = ['35/384', 0, '500/1113', '125/192', '-2187/6784', '11/84', 0]
B2 = ['5179/57600', 0, '7571/16695', '393/640', '-92097/339200', '187/2100', 0.025]
C = [0, 0.2, 0.3, 0.8, '8/9', 1, 1]

return _base(A=A, B1=B1, B2=B2, C=C, f=f, dt=dt, tol=tol,
adaptive=adaptive, show_code=show_code, var_type=var_type)


def ck(f=None, tol=None, adaptive=None, dt=None, show_code=None, var_type=None):
"""The Cash–Karp method for ordinary differential equations.

The Cash–Karp method was proposed by Professor Jeff R. Cash from Imperial College London
and Alan H. Karp from IBM Scientific Center. it uses six function evaluations to calculate
fourth- and fifth-order accurate solutions. The difference between these solutions is then
taken to be the error of the (fourth order) solution. This error estimate is very convenient
for adaptive stepsize integration algorithms.

It has the characteristics of:

- method stage = 6
- method order = 4
- Butcher Tables:

.. math::

\\begin{array}{l|lllll}
0 & & & & & & \\\\
1 / 5 & 1 / 5 & & & & & \\\\
3 / 10 & 3 / 40 & 9 / 40 & & & \\\\
3 / 5 & 3 / 10 & -9 / 10 & 6 / 5 & & \\\\
1 & -11 / 54 & 5 / 2 & -70 / 27 & 35 / 27 & & \\\\
7 / 8 & 1631 / 55296 & 175 / 512 & 575 / 13824 & 44275 / 110592 & 253 / 4096 & \\\\
\\hline & 37 / 378 & 0 & 250 / 621 & 125 / 594 & 0 & 512 / 1771 \\\\
& 2825 / 27648 & 0 & 18575 / 48384 & 13525 / 55296 & 277 / 14336 & 1 / 4
\\end{array}

References
----------

[1] https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method
[2] J. R. Cash, A. H. Karp. "A variable order Runge-Kutta method for initial value
problems with rapidly varying right-hand sides", ACM Transactions on Mathematical
Software 16: 201-222, 1990. doi:10.1145/79505.79507
"""

A = [(), (0.2,), (0.075, 0.225), (0.3, -0.9, 1.2),
('-11/54', 2.5, '-70/27', '35/27'),
('1631/55296', '175/512', '575/13824', '44275/110592', '253/4096')]
B1 = ['37/378', 0, '250/621', '125/594', 0, '512/1771']
B2 = ['2825/27648', 0, '18575/48384', '13525/55296', '277/14336', 0.25]
C = [0, 0.2, 0.3, 0.6, 1, 0.875]

return _base(A=A, B1=B1, B2=B2, C=C, f=f, dt=dt, tol=tol,
adaptive=adaptive, show_code=show_code, var_type=var_type)


def bs(f=None, tol=None, adaptive=None, dt=None, show_code=None, var_type=None):
"""The Bogacki–Shampine method for ordinary differential equations.

The Bogacki–Shampine method was proposed by Przemysław Bogacki and Lawrence F.
Shampine in 1989 (Bogacki & Shampine 1989). The Bogacki–Shampine method is a
Runge–Kutta method of order three with four stages with the First Same As Last
(FSAL) property, so that it uses approximately three function evaluations per
step. It has an embedded second-order method which can be used to implement adaptive step size.

It has the characteristics of:

- method stage = 4
- method order = 3
- Butcher Tables:

.. math::

\\begin{array}{l|lll}
0 & & & \\\\
1 / 2 & 1 / 2 & & \\\\
3 / 4 & 0 & 3 / 4 & \\\\
1 & 2 / 9 & 1 / 3 & 4 / 9 \\\\
\\hline & 2 / 9 & 1 / 3 & 4 / 9 & 0 \\\\
& 7 / 24 & 1 / 4 & 1 / 3 & 1 / 8
\\end{array}

References
----------

[1] https://en.wikipedia.org/wiki/Bogacki%E2%80%93Shampine_method
[2] Bogacki, Przemysław; Shampine, Lawrence F. (1989), "A 3(2) pair of Runge–Kutta
formulas", Applied Mathematics Letters, 2 (4): 321–325, doi:10.1016/0893-9659(89)90079-7
"""

A = [(), (0.5,), (0., 0.75), ('2/9', '1/3', '4/9'), ]
B1 = ['2/9', '1/3', '4/9', 0]
B2 = ['7/24', 0.25, '1/3', 0.125]
C = [0, 0.5, 0.75, 1]

return _base(A=A, B1=B1, B2=B2, C=C, f=f, dt=dt, tol=tol,
adaptive=adaptive, show_code=show_code, var_type=var_type)


def heun_euler(f=None, tol=None, adaptive=None, dt=None, show_code=None, var_type=None):
"""The Heun–Euler method for ordinary differential equations.

The simplest adaptive Runge–Kutta method involves combining Heun's method,
which is order 2, with the Euler method, which is order 1.

It has the characteristics of:

- method stage = 2
- method order = 1
- Butcher Tables:

.. math::

\\begin{array}{c|cc}
0&\\\\
1& 1 \\\\
\\hline
& 1/2& 1/2\\\\
& 1 & 0
\\end{array}

"""

A = [(), (1,)]
B1 = [0.5, 0.5]
B2 = [1, 0]
C = [0, 1]

return _base(A=A, B1=B1, B2=B2, C=C, f=f, dt=dt, tol=tol,
adaptive=adaptive, show_code=show_code, var_type=var_type)
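

# --- Illustrative sketch (not part of this module) ---
# What an embedded pair buys you: one step yields two solutions whose
# difference estimates the local error, which can drive step-size control.
# A minimal hand-written Heun-Euler step (hypothetical names):
#
# def heun_euler_step(y, t, f, dt):
#     k1 = f(y, t)
#     k2 = f(y + dt * k1, t + dt)
#     y_high = y + dt * (0.5 * k1 + 0.5 * k2)  # order 2 (weights B1)
#     y_low = y + dt * k1                      # order 1 (weights B2)
#     error = abs(y_high - y_low)              # local error estimate
#     return y_high, error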


def DOP853(f=None, tol=None, adaptive=None, dt=None, show_code=None, var_type=None):
"""The DOP853 method for ordinary differential equations.

DOP853 is an explicit Runge-Kutta method of order 8(5,3) due to Dormand & Prince
(with stepsize control and dense output).


References
----------

[1] E. Hairer, S.P. Norsett and G. Wanner, "Solving ordinary Differential Equations
I. Nonstiff Problems", 2nd edition. Springer Series in Computational Mathematics,
Springer-Verlag (1993).
[2] http://www.unige.ch/~hairer/software.html
"""
raise NotImplementedError('DOP853 is not implemented yet.')

+ 347 - 0 brainpy/integrators/ode/rk_methods.py

@@ -0,0 +1,347 @@
# -*- coding: utf-8 -*-


from brainpy import backend
from .wrapper import rk_wrapper
from .wrapper import wrapper_of_rk2

__all__ = [
'euler',
'midpoint',
'heun2',
'ralston2',
'rk2',
'rk3',
'heun3',
'ralston3',
'ssprk3',
'rk4',
'ralston4',
'rk4_38rule',
]


def _base(A, B, C, f, show_code, dt):
dt = backend.get_dt() if dt is None else dt
show_code = False if show_code is None else show_code

if f is None:
return lambda f: rk_wrapper(f, show_code=show_code, dt=dt, A=A, B=B, C=C)
else:
return rk_wrapper(f, show_code=show_code, dt=dt, A=A, B=B, C=C)


def euler(f=None, show_code=None, dt=None):
"""The Euler method is first order. The lack of stability
and accuracy limits its popularity mainly to use as a
simple introductory example of a numeric solution method.
"""
A = [(), ]
B = [1]
C = [0]
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)


def midpoint(f=None, show_code=None, dt=None):
"""midpoint method for ordinary differential equations.

The (explicit) midpoint method is a second-order method
with two stages.

It has the characteristics of:

- method stage = 2
- method order = 2
- Butcher Tables:

.. math::

\\begin{array}{c|cc}
0 & 0 & 0 \\\\
1 / 2 & 1 / 2 & 0 \\\\
\\hline & 0 & 1
\\end{array}

"""
A = [(), (0.5,)]
B = [0, 1]
C = [0, 0.5]
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)


def heun2(f=None, show_code=None, dt=None):
"""Heun's method for ordinary differential equations.

Heun's method is a second-order method with two stages.
It is also known as the explicit trapezoid rule, improved
Euler's method, or modified Euler's method.

It has the characteristics of:

- method stage = 2
- method order = 2
- Butcher Tables:

.. math::

\\begin{array}{c|cc}
0.0 & 0.0 & 0.0 \\\\
1.0 & 1.0 & 0.0 \\\\
\\hline & 0.5 & 0.5
\\end{array}

"""
A = [(), (1,)]
B = [0.5, 0.5]
C = [0, 1]
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)


def ralston2(f=None, show_code=None, dt=None):
"""Ralston's method for ordinary differential equations.

Ralston's method is a second-order method with two stages and
a minimum local error bound.

It has the characteristics of:

- method stage = 2
- method order = 2
- Butcher Tables:

.. math::

\\begin{array}{c|cc}
0 & 0 & 0 \\\\
2 / 3 & 2 / 3 & 0 \\\\
\\hline & 1 / 4 & 3 / 4
\\end{array}
"""
A = [(), ('2/3',)]
B = [0.25, 0.75]
C = [0, '2/3']
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)


def rk2(f=None, show_code=None, dt=None, beta=None):
"""Runge–Kutta methods for ordinary differential equations.

Generic second-order method.

It has the characteristics of:

- method stage = 2
- method order = 2
- Butcher Tables:

.. math::

\\begin{array}{c|cc}
0 & 0 & 0 \\\\
\\beta & \\beta & 0 \\\\
\\hline & 1 - {1 \\over 2\\beta} & {1 \\over 2\\beta}
\\end{array}
"""
beta = 2 / 3 if beta is None else beta
dt = backend.get_dt() if dt is None else dt
show_code = False if show_code is None else show_code

if f is None:
return lambda f: wrapper_of_rk2(f, show_code=show_code, dt=dt, beta=beta)
else:
return wrapper_of_rk2(f, show_code=show_code, dt=dt, beta=beta)


def rk3(f=None, show_code=None, dt=None):
"""Classical third-order Runge-Kutta method for ordinary differential equations.

It has the characteristics of:

- method stage = 3
- method order = 3
- Butcher Tables:

.. math::

\\begin{array}{c|ccc}
0 & 0 & 0 & 0 \\\\
1 / 2 & 1 / 2 & 0 & 0 \\\\
1 & -1 & 2 & 0 \\\\
\\hline & 1 / 6 & 2 / 3 & 1 / 6
\\end{array}

"""
A = [(), (0.5,), (-1, 2)]
B = ['1/6', '2/3', '1/6']
C = [0, 0.5, 1]
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)


def heun3(f=None, show_code=None, dt=None):
"""Heun's third-order method for ordinary differential equations.

It has the characteristics of:

- method stage = 3
- method order = 3
- Butcher Tables:

.. math::

\\begin{array}{c|ccc}
0 & 0 & 0 & 0 \\\\
1 / 3 & 1 / 3 & 0 & 0 \\\\
2 / 3 & 0 & 2 / 3 & 0 \\\\
\\hline & 1 / 4 & 0 & 3 / 4
\\end{array}

"""
A = [(), ('1/3',), (0, '2/3')]
B = [0.25, 0, 0.75]
C = [0, '1/3', '2/3']
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)


def ralston3(f=None, show_code=None, dt=None):
"""Ralston's third-order method for ordinary differential equations.

It has the characteristics of:

- method stage = 3
- method order = 3
- Butcher Tables:

.. math::

\\begin{array}{c|ccc}
0 & 0 & 0 & 0 \\\\
1 / 2 & 1 / 2 & 0 & 0 \\\\
3 / 4 & 0 & 3 / 4 & 0 \\\\
\\hline & 2 / 9 & 1 / 3 & 4 / 9
\\end{array}

References
----------

.. [1] Ralston, Anthony (1962). "Runge-Kutta Methods with Minimum Error Bounds".
Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0

"""
A = [(), (0.5,), (0, 0.75)]
B = ['2/9', '1/3', '4/9']
C = [0, 0.5, 0.75]
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)


def ssprk3(f=None, show_code=None, dt=None):
"""Third-order Strong Stability Preserving Runge-Kutta (SSPRK3).

It has the characteristics of:

- method stage = 3
- method order = 3
- Butcher Tables:

.. math::

\\begin{array}{c|ccc}
0 & 0 & 0 & 0 \\\\
1 & 1 & 0 & 0 \\\\
1 / 2 & 1 / 4 & 1 / 4 & 0 \\\\
\\hline & 1 / 6 & 1 / 6 & 2 / 3
\\end{array}

"""
A = [(), (1,), (0.25, 0.25)]
B = ['1/6', '1/6', '2/3']
C = [0, 1, 0.5]
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)


def rk4(f=None, show_code=None, dt=None):
"""Classical fourth-order Runge-Kutta method for ordinary differential equations.

It has the characteristics of:

- method stage = 4
- method order = 4
- Butcher Tables:

.. math::

\\begin{array}{c|cccc}
0 & 0 & 0 & 0 & 0 \\\\
1 / 2 & 1 / 2 & 0 & 0 & 0 \\\\
1 / 2 & 0 & 1 / 2 & 0 & 0 \\\\
1 & 0 & 0 & 1 & 0 \\\\
\\hline & 1 / 6 & 1 / 3 & 1 / 3 & 1 / 6
\\end{array}

"""

A = [(), (0.5,), (0., 0.5), (0., 0., 1)]
B = ['1/6', '1/3', '1/3', '1/6']
C = [0, 0.5, 0.5, 1]
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)
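

# A hedged usage sketch (the derivative name and constants are illustrative):
# `rk4` works as a decorator or as a plain call, and the wrapped integrator
# advances the state by one step of size `dt`, which is also exposed as an
# attribute on the returned function.
#
#   @rk4
#   def dV(V, t, Iext):
#       return (-V + Iext) / 10.
#
#   V, t = 0., 0.
#   for _ in range(1000):
#       V = dV(V, t, 1.)   # one classical RK4 step
#       t += dV.dt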


def ralston4(f=None, show_code=None, dt=None):
"""Ralston's fourth-order method for ordinary differential equations.

It has the characteristics of:

- method stage = 4
- method order = 4
- Butcher Tables:

.. math::

\\begin{array}{c|cccc}
0 & 0 & 0 & 0 & 0 \\\\
.4 & .4 & 0 & 0 & 0 \\\\
.45573725 & .29697761 & .15875964 & 0 & 0 \\\\
1 & .21810040 & -3.05096516 & 3.83286476 & 0 \\\\
\\hline & .17476028 & -.55148066 & 1.20553560 & .17118478
\\end{array}

References
----------

.. [1] Ralston, Anthony (1962). "Runge-Kutta Methods with Minimum Error Bounds".
Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0

"""
A = [(), (.4,), (.29697761, .15875964), (.21810040, -3.05096516, 3.83286476)]
B = [.17476028, -.55148066, 1.20553560, .17118478]
C = [0, .4, .45573725, 1]
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)


def rk4_38rule(f=None, show_code=None, dt=None):
"""3/8-rule fourth-order method for ordinary differential equations.

This method is less well known than the "classical" method, but it is just
as classical: both were proposed in the same paper (Kutta, 1901).

It has the characteristics of:

- method stage = 4
- method order = 4
- Butcher Tables:

.. math::

\\begin{array}{c|cccc}
0 & 0 & 0 & 0 & 0 \\\\
1 / 3 & 1 / 3 & 0 & 0 & 0 \\\\
2 / 3 & -1 / 3 & 1 & 0 & 0 \\\\
1 & 1 & -1 & 1 & 0 \\\\
\\hline & 1 / 8 & 3 / 8 & 3 / 8 & 1 / 8
\\end{array}

"""
A = [(), ('1/3',), ('-1/3', '1'), (1, -1, 1)]
B = ['1/8', '3/8', '3/8', '1/8']
C = [0, '1/3', '2/3', 1]
return _base(A=A, B=B, C=C, f=f, show_code=show_code, dt=dt)

+ 322
- 0
brainpy/integrators/ode/wrapper.py View File

@@ -0,0 +1,322 @@
# -*- coding: utf-8 -*-

from pprint import pprint

from brainpy.integrators import constants
from brainpy.integrators import utils

__all__ = [
'rk_wrapper',
'adaptive_rk_wrapper',
'wrapper_of_rk2',
]

_ODE_UNKNOWN_NO = 0


def _f_names(f):
if f.__name__.isidentifier():
f_name = f.__name__
else:
global _ODE_UNKNOWN_NO
f_name = f'ode_unknown_{_ODE_UNKNOWN_NO}'
_ODE_UNKNOWN_NO += 1
f_new_name = constants.NAME_PREFIX + f_name
return f_new_name


def _step(vars, dt_var, A, C, code_lines, other_args):
# steps
for si, sval in enumerate(A):
# k-step arguments
k_args = []
for v in vars:
k_arg = f'{v}'
for j, sv in enumerate(sval):
if sv not in [0., '0.', '0']:
if sv in ['1.', '1', 1.]:
k_arg += f' + {dt_var} * d{v}_k{j + 1}'
else:
k_arg += f' + {dt_var} * d{v}_k{j + 1} * {sv}'
if k_arg != v:
name = f'k{si + 1}_{v}_arg'
code_lines.append(f' {name} = {k_arg}')
k_args.append(name)
else:
k_args.append(v)

t_arg = 't'
if C[si] not in [0., '0.', '0']:
if C[si] in ['1.', '1', 1.]:
t_arg += f' + {dt_var}'
else:
t_arg += f' + {dt_var} * {C[si]}'
name = f'k{si + 1}_t_arg'
code_lines.append(f' {name} = {t_arg}')
k_args.append(name)
else:
k_args.append('t')

# k-step derivative names
k_derivatives = [f'd{v}_k{si + 1}' for v in vars]

# k-step code line
code_lines.append(f' {", ".join(k_derivatives)} = f('
f'{", ".join(k_args + other_args[1:])})')


def _update(vars, dt_var, B, code_lines):
return_args = []
for v in vars:
result = v
for i, b1 in enumerate(B):
if b1 not in [0., '0.', '0']:
result += f' + d{v}_k{i + 1} * {dt_var} * {b1}'
code_lines.append(f' {v}_new = {result}')
return_args.append(f'{v}_new')
return return_args
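

# For orientation, roughly what `_step` and `_update` emit for Heun's method
# (A=[(), (1,)], B=[0.5, 0.5], C=[0, 1]) with a single variable `v` and no
# parameters beyond `t` (`_prefix_` stands in for `constants.NAME_PREFIX`,
# whose value is not shown here):
#
#   def _prefix_dv(v, t):
#       dv_k1 = f(v, t)
#       k2_v_arg = v + dt * dv_k1
#       k2_t_arg = t + dt
#       dv_k2 = f(k2_v_arg, k2_t_arg)
#       v_new = v + dv_k1 * dt * 0.5 + dv_k2 * dt * 0.5
#       return v_new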


def _compile_and_assign_attrs(code_lines, code_scope, show_code,
func_name, variables, parameters, dt):
# compile
code = '\n'.join(code_lines)
if show_code:
print(code)
print()
pprint(code_scope)
print()
utils.numba_func(code_scope, ['f'])
exec(compile(code, '', 'exec'), code_scope)

# attribute assignment
new_f = code_scope[func_name]
new_f.variables = variables
new_f.parameters = parameters
new_f.origin_f = code_scope['f']
new_f.dt = dt
utils.numba_func(code_scope, func_name)
return code_scope[func_name]


def rk_wrapper(f, show_code, dt, A, B, C):
"""Runge–Kutta methods for ordinary differential equation.

For the system,

.. math::

\\frac{dy}{dt} = f(t, y)


Explicit Runge-Kutta methods take the form

.. math::

k_{i}=f\\left(t_{n}+c_{i}h,y_{n}+h\\sum _{j=1}^{s}a_{ij}k_{j}\\right) \\\\
y_{n+1}=y_{n}+h \\sum_{i=1}^{s} b_{i} k_{i}

Each explicit Runge-Kutta method is defined by its Butcher tableau,
which collects the coefficients of the method as follows:

.. math::

\\begin{array}{c|cccc}
c_{1} & a_{11} & a_{12} & \\ldots & a_{1 s} \\\\
c_{2} & a_{21} & a_{22} & \\ldots & a_{2 s} \\\\
\\vdots & \\vdots & \\vdots & \\ddots & \\vdots \\\\
c_{s} & a_{s 1} & a_{s 2} & \\ldots & a_{s s} \\\\
\\hline & b_{1} & b_{2} & \\ldots & b_{s}
\\end{array}

Parameters
----------
f : callable
The derivative function.
show_code : bool
Whether to show the formatted code.
dt : float
The numerical precision.
A : tuple, list
The A matrix in the Butcher tableau.
B : tuple, list
The B vector in the Butcher tableau.
C : tuple, list
The C vector in the Butcher tableau.

Returns
-------
integral_func : callable
The one-step numerical integration function.
"""
class_kw, variables, parameters, arguments = utils.get_args(f)
dt_var = 'dt'
func_name = _f_names(f)

# code scope
code_scope = {'f': f, 'dt': dt}

# code lines
code_lines = [f'def {func_name}({", ".join(arguments)}):']

# step stage
_step(variables, dt_var, A, C, code_lines, parameters)

# variable update
return_args = _update(variables, dt_var, B, code_lines)

# returns
code_lines.append(f' return {", ".join(return_args)}')

# compilation
return _compile_and_assign_attrs(
code_lines=code_lines, code_scope=code_scope, show_code=show_code,
func_name=func_name, variables=variables, parameters=parameters, dt=dt)
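

# A hedged sketch of calling `rk_wrapper` directly with an explicit Butcher
# tableau (here the midpoint method); the named integrators above normally do
# this for you via `_base`:
#
#   def dv(v, t):
#       return -v
#
#   step = rk_wrapper(dv, show_code=True, dt=0.01,
#                     A=[(), (0.5,)], B=[0, 1], C=[0, 0.5])
#   v = step(1.0, 0.0)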


def adaptive_rk_wrapper(f, dt, A, B1, B2, C, tol, adaptive, show_code, var_type):
"""Adaptive Runge-Kutta numerical method for ordinary differential equations.

The embedded methods are designed to produce an estimate of the local
truncation error of a single Runge-Kutta step and, as a result, allow the
error to be controlled with an adaptive step size. This is done by having
two methods in the tableau, one of order :math:`p` and one of order :math:`p-1`.

The lower-order step is given by

.. math::

y^*_{n+1} = y_n + h\\sum_{i=1}^s b^*_i k_i,

where the :math:`k_{i}` are the same as for the higher order method. Then the error is

.. math::

e_{n+1} = y_{n+1} - y^*_{n+1} = h\\sum_{i=1}^s (b_i - b^*_i) k_i,


which is :math:`O(h^{p})`. The Butcher Tableau for this kind of method is extended to
give the values of :math:`b_{i}^{*}`

.. math::

\\begin{array}{c|cccc}
c_1 & a_{11} & a_{12}& \\dots & a_{1s}\\\\
c_2 & a_{21} & a_{22}& \\dots & a_{2s}\\\\
\\vdots & \\vdots & \\vdots& \\ddots& \\vdots\\\\
c_s & a_{s1} & a_{s2}& \\dots & a_{ss} \\\\
\\hline & b_1 & b_2 & \\dots & b_s\\\\
& b_1^* & b_2^* & \\dots & b_s^*\\\\
\\end{array}


Parameters
----------
f : callable
The derivative function.
show_code : bool
Whether to show the formatted code.
dt : float
The numerical precision.
A : tuple, list
The A matrix in the Butcher tableau.
B1 : tuple, list
The B1 vector in the Butcher tableau.
B2 : tuple, list
The B2 vector in the Butcher tableau.
C : tuple, list
The C vector in the Butcher tableau.
adaptive : bool
Whether to use the adaptive step-size control.
tol : float
The error tolerance used by the step-size controller.
var_type : str
The type of the variables (see `constants.SUPPORTED_VAR_TYPE`).

Returns
-------
integral_func : callable
The one-step numerical integration function.
"""
assert var_type in constants.SUPPORTED_VAR_TYPE, \
f'"var_type" only supports {constants.SUPPORTED_VAR_TYPE}, ' \
f'not {var_type}.'

class_kw, variables, parameters, arguments = utils.get_args(f)
dt_var = 'dt'
func_name = _f_names(f)

if adaptive:
# code scope
code_scope = {'f': f, 'tol': tol}
arguments = list(arguments) + ['dt']
else:
# code scope
code_scope = {'f': f, 'dt': dt}

# code lines
code_lines = [f'def {func_name}({", ".join(arguments)}):']
# stage steps
_step(variables, dt_var, A, C, code_lines, parameters)
# variable update
return_args = _update(variables, dt_var, B1, code_lines)

# error adaptive item
if adaptive:
errors = []
for v in variables:
result = []
for i, (b1, b2) in enumerate(zip(B1, B2)):
if isinstance(b1, str):
b1 = eval(b1)
if isinstance(b2, str):
b2 = eval(b2)
diff = b1 - b2
if diff != 0.:
result.append(f'd{v}_k{i + 1} * {dt_var} * {diff}')
if len(result) > 0:
if var_type == constants.SCALAR_VAR:
code_lines.append(f' {v}_te = abs({" + ".join(result)})')
else:
code_lines.append(f' {v}_te = sum(abs({" + ".join(result)}))')
errors.append(f'{v}_te')
if len(errors) > 0:
code_lines.append(f' error = {" + ".join(errors)}')
code_lines.append(f' if error > tol:')
code_lines.append(f' {dt_var}_new = 0.9 * {dt_var} * (tol / error) ** 0.2')
code_lines.append(f' else:')
code_lines.append(f' {dt_var}_new = {dt_var}')
return_args.append(f'{dt_var}_new')

# returns
code_lines.append(f' return {", ".join(return_args)}')

# compilation
return _compile_and_assign_attrs(
code_lines=code_lines, code_scope=code_scope, show_code=show_code,
func_name=func_name, variables=variables, parameters=parameters, dt=dt)
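

# A worked example of the emitted step-size controller: with tol = 1e-3 and a
# measured error of 4e-3,
#   dt_new = 0.9 * dt * (1e-3 / 4e-3) ** 0.2 ~= 0.68 * dt,
# i.e. the safety factor 0.9 times (tol/error)**(1/5); the 1/5 exponent is the
# classical choice for a fourth-order error estimate, applied here to any
# tableau. Note that this controller only shrinks the step when the error
# exceeds `tol`; it never enlarges it when the error is small.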


def wrapper_of_rk2(f, show_code, dt, beta):
class_kw, variables, parameters, arguments = utils.get_args(f)
func_name = _f_names(f)

code_scope = {'f': f, 'dt': dt, 'beta': beta,
'k1': 1 - 1 / (2 * beta), 'k2': 1 / (2 * beta)}
code_lines = [f'def {func_name}({", ".join(arguments)}):']
# k1
k1_args = variables + parameters
k1_vars_d = [f'd{v}_k1' for v in variables]
code_lines.append(f' {", ".join(k1_vars_d)} = f({", ".join(k1_args)})')
# k2
k2_args = [f'{v} + d{v}_k1 * dt * beta' for v in variables]
k2_args.append('t + dt * beta')
k2_args.extend(parameters[1:])
k2_vars_d = [f'd{v}_k2' for v in variables]
code_lines.append(f' {", ".join(k2_vars_d)} = f({", ".join(k2_args)})')
# returns
for v, k1, k2 in zip(variables, k1_vars_d, k2_vars_d):
code_lines.append(f' {v}_new = {v} + ({k1} * k1 + {k2} * k2) * dt')
return_vars = [f'{v}_new' for v in variables]
code_lines.append(f' return {", ".join(return_vars)}')

return _compile_and_assign_attrs(
code_lines=code_lines, code_scope=code_scope, show_code=show_code,
func_name=func_name, variables=variables, parameters=parameters, dt=dt)

+ 11
- 0
brainpy/integrators/sde/__init__.py View File

@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-

"""
Numerical methods for stochastic differential equations.
"""

from .euler_and_milstein import *
from .srk_scalar import *
# from .srk_strong import *
# from .srk_weak import *


+ 53
- 0
brainpy/integrators/sde/common.py View File

@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-

from pprint import pprint

from brainpy.integrators import constants
from brainpy.integrators import utils

_SDE_UNKNOWN_NO = 0


def basic_info(f, g):
vdt = 'dt'
if f.__name__.isidentifier():
func_name = f.__name__
elif g.__name__.isidentifier():
func_name = g.__name__
else:
global _SDE_UNKNOWN_NO
func_name = f'unknown_sde{_SDE_UNKNOWN_NO}'
_SDE_UNKNOWN_NO += 1
func_new_name = constants.NAME_PREFIX + func_name
class_kw, variables, parameters, arguments = utils.get_args(f)
return vdt, variables, parameters, arguments, func_new_name


def return_compile_and_assign_attrs(code_lines, code_scope, show_code,
variables, parameters, func_name,
sde_type, var_type, wiener_type, dt):
# returns
new_vars = [f'{var}_new' for var in variables]
code_lines.append(f' return {", ".join(new_vars)}')

# compile
code = '\n'.join(code_lines)
if show_code:
print(code)
print()
pprint(code_scope)
print()
utils.numba_func(code_scope, ['f', 'g'])
exec(compile(code, '', 'exec'), code_scope)

# attribute assignment
new_f = code_scope[func_name]
new_f.variables = variables
new_f.parameters = parameters
new_f.origin_f = code_scope['f']
new_f.origin_g = code_scope['g']
new_f.sde_type = sde_type
new_f.var_type = var_type
new_f.wiener_type = wiener_type
new_f.dt = dt
utils.numba_func(code_scope, func_name)
return code_scope[func_name]

+ 276
- 0
brainpy/integrators/sde/euler_and_milstein.py View File

@@ -0,0 +1,276 @@
# -*- coding: utf-8 -*-

from brainpy import backend
from brainpy.integrators import constants
from . import common

__all__ = [
'euler',
'milstein',
]


def _df_and_dg(code_lines, variables, parameters):
# 1. df
# df = f(x, t, *args)
all_df = [f'{var}_df' for var in variables]
code_lines.append(f' {", ".join(all_df)} = f({", ".join(variables + parameters)})')

# 2. dg
# dg = g(x, t, *args)
all_dg = [f'{var}_dg' for var in variables]
code_lines.append(f' {", ".join(all_dg)} = g({", ".join(variables + parameters)})')
code_lines.append(' ')


def _dfdt(code_lines, variables, vdt):
for var in variables:
code_lines.append(f' {var}_dfdt = {var}_df * {vdt}')
code_lines.append(' ')


def _noise_terms(code_lines, variables):
num_vars = len(variables)
if num_vars > 1:
code_lines.append(f' all_dW = backend.normal(0.0, dt_sqrt, ({num_vars},)+backend.shape({variables[0]}_dg))')
for i, var in enumerate(variables):
code_lines.append(f' {var}_dW = all_dW[{i}]')
else:
var = variables[0]
code_lines.append(f' {var}_dW = backend.normal(0.0, dt_sqrt, backend.shape({var}))')
code_lines.append(' ')
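

# Note: `backend.normal(0.0, dt_sqrt, shape)` draws the Wiener increments
# dW ~ N(0, dt) directly -- the standard deviation `dt_sqrt` encodes
# Var(dW) = dt, so e.g. dt = 0.01 gives increments with std 0.1.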


# ----------
# Wrapper
# ----------


def _wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code):
"""The base function to format a SRK method.

Parameters
----------
f : callable
The drift function of the SDE.
g : callable
The diffusion function of the SDE.
dt : float
The numerical precision.
sde_type : str
"constants.ITO_SDE" : Ito's Stochastic Calculus.
"constants.STRA_SDE" : Stratonovich's Stochastic Calculus.
wiener_type : str
The type of Wiener process (scalar or vector).
var_type : str
"scalar" : with the shape of ().
"population" : with the shape of (N,) or (N1, N2) or (N1, N2, ...).
"system" : with the shape of (d,), (d, N), or (d, N1, N2).
show_code : bool
Whether to show the formatted code.

Returns
-------
numerical_func : callable
The numerical function.
"""

sde_type = constants.ITO_SDE if sde_type is None else sde_type
assert sde_type in constants.SUPPORTED_SDE_TYPE, f'Currently, BrainPy only supports SDE types: ' \
f'{constants.SUPPORTED_SDE_TYPE}. But we got {sde_type}.'

var_type = constants.POPU_VAR if var_type is None else var_type
assert var_type in constants.SUPPORTED_VAR_TYPE, f'Currently, BrainPy only supports variable types: ' \
f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.'

wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type
assert wiener_type in constants.SUPPORTED_WIENER_TYPE, f'Currently, BrainPy only supports Wiener ' \
f'Process types: {constants.SUPPORTED_WIENER_TYPE}. ' \
f'But we got {wiener_type}.'

show_code = False if show_code is None else show_code
dt = backend.get_dt() if dt is None else dt

if f is not None and g is not None:
return wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
var_type=var_type, wiener_type=wiener_type)

elif f is not None:
return lambda g: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
var_type=var_type, wiener_type=wiener_type)

elif g is not None:
return lambda f: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
var_type=var_type, wiener_type=wiener_type)

else:
raise ValueError('Must provide "f" or "g".')


def _euler_wrapper(f, g, dt, sde_type, var_type, wiener_type, show_code):
vdt, variables, parameters, arguments, func_name = common.basic_info(f=f, g=g)

# 1. code scope
code_scope = {'f': f, 'g': g, vdt: dt, f'{vdt}_sqrt': dt ** 0.5, 'backend': backend}

# 2. code lines
code_lines = [f'def {func_name}({", ".join(arguments)}):']

# 2.1 df, dg
_df_and_dg(code_lines, variables, parameters)

# 2.2 dfdt
_dfdt(code_lines, variables, vdt)

# 2.3 dW
_noise_terms(code_lines, variables)

# 2.4 dgdW
# ----
# SCALAR_WIENER : dg * dW
# VECTOR_WIENER : backend.sum(dg * dW, axis=-1)

if wiener_type == constants.SCALAR_WIENER:
for var in variables:
code_lines.append(f' {var}_dgdW = {var}_dg * {var}_dW')
else:
for var in variables:
code_lines.append(f' {var}_dgdW = backend.sum({var}_dg * {var}_dW, axis=-1)')
code_lines.append(' ')

if sde_type == constants.ITO_SDE:
# 2.5 new var
# ----
# y = x + dfdt + dgdW
for var in variables:
code_lines.append(f' {var}_new = {var} + {var}_dfdt + {var}_dgdW')
code_lines.append(' ')

elif sde_type == constants.STRA_SDE:
# 2.5 y_bar = x + dgdW
all_bar = [f'{var}_bar' for var in variables]
for var in variables:
code_lines.append(f' {var}_bar = {var} + {var}_dgdW')
code_lines.append(' ')

# 2.6 dg_bar = g(y_bar, t, *args)
all_dg_bar = [f'{var}_dg_bar' for var in variables]
code_lines.append(f' {", ".join(all_dg_bar)} = g({", ".join(all_bar + parameters)})')

# 2.7 dgdW2
# ----
# SCALAR_WIENER : dgdW2 = dg_bar * dW
# VECTOR_WIENER : dgdW2 = backend.sum(dg_bar * dW, axis=-1)
if wiener_type == constants.SCALAR_WIENER:
for var in variables:
code_lines.append(f' {var}_dgdW2 = {var}_dg_bar * {var}_dW')
else:
for var in variables:
code_lines.append(f' {var}_dgdW2 = backend.sum({var}_dg_bar * {var}_dW, axis=-1)')
code_lines.append(' ')

# 2.8 new var
# ----
# y = x + dfdt + 0.5 * (dgdW + dgdW2)
for var in variables:
code_lines.append(f' {var}_new = {var} + {var}_dfdt + 0.5 * ({var}_dgdW + {var}_dgdW2)')
code_lines.append(' ')
else:
raise ValueError(f'Unknown SDE type: {sde_type}. We only '
f'support {constants.SUPPORTED_SDE_TYPE}.')

# return and compile
return common.return_compile_and_assign_attrs(
code_lines=code_lines, code_scope=code_scope, show_code=show_code,
variables=variables, parameters=parameters, func_name=func_name,
sde_type=sde_type, var_type=var_type, wiener_type=wiener_type, dt=dt)


def _milstein_wrapper(f, g, dt, sde_type, var_type, wiener_type, show_code):
vdt, variables, parameters, arguments, func_name = common.basic_info(f=f, g=g)

# 1. code scope
code_scope = {'f': f, 'g': g, vdt: dt, f'{vdt}_sqrt': dt ** 0.5, 'backend': backend}

# 2. code lines
code_lines = [f'def {func_name}({", ".join(arguments)}):']

# 2.1 df, dg
_df_and_dg(code_lines, variables, parameters)

# 2.2 dfdt
_dfdt(code_lines, variables, vdt)

# 2.3 dW
_noise_terms(code_lines, variables)

# 2.4 dgdW
# ----
# dg * dW
for var in variables:
code_lines.append(f' {var}_dgdW = {var}_dg * {var}_dW')
code_lines.append(' ')

# 2.5 df_bar = x + dfdt + dg * dt_sqrt (summed over the last axis for a vector Wiener process)
all_df_bar = [f'{var}_df_bar' for var in variables]
if wiener_type == constants.SCALAR_WIENER:
for var in variables:
code_lines.append(f' {var}_df_bar = {var} + {var}_dfdt + {var}_dg * {vdt}_sqrt')
else:
for var in variables:
code_lines.append(f' {var}_df_bar = {var} + {var}_dfdt + backend.sum('
f'{var}_dg * {vdt}_sqrt, axis=-1)')

# 2.6 dg_bar = g(df_bar, t, *args)
all_dg_bar = [f'{var}_dg_bar' for var in variables]
code_lines.append(f' {", ".join(all_dg_bar)} = g({", ".join(all_df_bar + parameters)})')
code_lines.append(' ')

# 2.7 dgdW2
# ----
# dgdW2 = 0.5 * (dg_bar - dg) * (dW * dW / dt_sqrt - dt_sqrt)
if sde_type == constants.ITO_SDE:
for var in variables:
code_lines.append(f' {var}_dgdW2 = 0.5 * ({var}_dg_bar - {var}_dg) * '
f'({var}_dW * {var}_dW / {vdt}_sqrt - {vdt}_sqrt)')
elif sde_type == constants.STRA_SDE:
for var in variables:
code_lines.append(f' {var}_dgdW2 = 0.5 * ({var}_dg_bar - {var}_dg) * '
f'{var}_dW * {var}_dW / {vdt}_sqrt')
else:
raise ValueError(f'Unknown SDE type: {sde_type}')
code_lines.append(' ')

# 2.8 new var
# ----
# SCALAR_WIENER : y = x + dfdt + dgdW + dgdW2
# VECTOR_WIENER : y = x + dfdt + backend.sum(dgdW + dgdW2, axis=-1)
if wiener_type == constants.SCALAR_WIENER:
for var in variables:
code_lines.append(f' {var}_new = {var} + {var}_dfdt + {var}_dgdW + {var}_dgdW2')
elif wiener_type == constants.VECTOR_WIENER:
for var in variables:
code_lines.append(f' {var}_new = {var} + {var}_dfdt + backend.sum({var}_dgdW + {var}_dgdW2, axis=-1)')
else:
raise ValueError(f'Unknown Wiener process: {wiener_type}')
code_lines.append(' ')

# return and compile
return common.return_compile_and_assign_attrs(
code_lines=code_lines, code_scope=code_scope, show_code=show_code,
variables=variables, parameters=parameters, func_name=func_name,
sde_type=sde_type, var_type=var_type, wiener_type=wiener_type, dt=dt)


# ------------------
# Numerical methods
# ------------------


def euler(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, show_code=None):
return _wrap(_euler_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type,
wiener_type=wiener_type, show_code=show_code)


def milstein(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, show_code=None):
return _wrap(_milstein_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type,
wiener_type=wiener_type, show_code=show_code)
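

# A hedged usage sketch (the drift/diffusion names and constants are
# illustrative): an Ornstein-Uhlenbeck process
# dx = (-x / tau) dt + sigma dW (here sigma = 0.5) integrated with the
# Euler-Maruyama scheme. `f` and `g` share the
# `(variables..., t, parameters...)` signature, and `euler` can also be
# partially applied, e.g. `euler(f=df)(dg)`.
#
#   def df(x, t, tau):
#       return -x / tau
#
#   def dg(x, t, tau):
#       return 0.5   # additive (constant) noise amplitude
#
#   step = euler(f=df, g=dg, dt=0.01)
#   x, t = 1., 0.
#   for _ in range(1000):
#       x = step(x, t, 10.)
#       t += 0.01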

+ 218
- 0
brainpy/integrators/sde/exp_euler.py View File

@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-

import numpy as np
import sympy

from brainpy import backend
from brainpy import errors
from brainpy import tools
from brainpy.integrators import ast_analysis

__all__ = [
'exponential_euler',
]


class Integrator(object):
def __init__(self, diff_eq):
if not isinstance(diff_eq, ast_analysis.DiffEquation):
if diff_eq.__class__.__name__ != 'function':
raise errors.IntegratorError('"diff_eq" must be a function or an instance of DiffEquation .')
else:
diff_eq = ast_analysis.DiffEquation(func=diff_eq)
self.diff_eq = diff_eq
self._update_code = None
self._update_func = None

def __call__(self, y0, t, *args):
return self._update_func(y0, t, *args)

def _compile(self):
# function arguments
func_args = ', '.join([f'_{arg}' for arg in self.diff_eq.func_args])

# function codes
func_code = f'def {self.py_func_name}({func_args}): \n'
func_code += tools.indent(self._update_code + '\n' + f'return _res')
func_code = tools.NoiseHandler.normal_pattern.sub(
tools.NoiseHandler.vector_replace_f, func_code)

# function scope
code_scopes = {'numpy': np}
for k_, v_ in self.code_scope.items():
if backend.is_jit() and callable(v_):
v_ = tools.numba_func(v_)
code_scopes[k_] = v_
code_scopes.update(ast_analysis.get_mapping_scope())
code_scopes['_normal_like_'] = backend.normal_like

# function compilation
exec(compile(func_code, '', 'exec'), code_scopes)
func = code_scopes[self.py_func_name]
if backend.is_jit():
func = tools.jit(func)
self._update_func = func

@staticmethod
def get_integral_step(diff_eq, *args):
raise NotImplementedError

@property
def py_func_name(self):
return self.diff_eq.func_name

@property
def update_code(self):
return self._update_code

@property
def update_func(self):
return self._update_func

@property
def code_scope(self):
scope = self.diff_eq.func_scope
if backend.run_on_cpu():
scope['_normal_like_'] = backend.normal_like
return scope


class ExponentialEuler(Integrator):
"""First order, explicit exponential Euler method.

For an ODE equation of the form

.. math::

y^{\\prime}=f(y), \\quad y(0)=y_{0}

its scheme is given by

.. math::

y_{n+1}= y_{n}+h \\varphi(hA) f (y_{n})

where :math:`A=f^{\\prime}(y_{n})` and :math:`\\varphi(z)=\\frac{e^{z}-1}{z}`.

For linear ODE system: :math:`y^{\\prime} = Ay + B`,
the above equation is equal to

.. math::

y_{n+1}= y_{n}e^{hA}-B/A(1-e^{hA})

For an SDE of the form

.. math::

d y=(Ay+ F(y))dt + g(y)dW(t) = f(y)dt + g(y)dW(t), \\quad y(0)=y_{0}

its scheme is given by [1]_

.. math::

y_{n+1} & =e^{\\Delta t A}(y_{n}+ g(y_n)\\Delta W_{n})+\\varphi(\\Delta t A) F(y_{n}) \\Delta t \\\\
&= y_n + \\Delta t \\varphi(\\Delta t A) f(y) + e^{\\Delta t A}g(y_n)\\Delta W_{n}

where :math:`\\varphi(z)=\\frac{e^{z}-1}{z}`.

Parameters
----------
diff_eq : DiffEquation
The differential equation.

Returns
-------
func : callable
The one-step numerical integrator function.

References
----------
.. [1] Erdoğan, Utku, and Gabriel J. Lord. "A new class of exponential integrators for stochastic
differential equations with multiplicative noise." arXiv preprint arXiv:1608.07096 (2016).
"""

def __init__(self, diff_eq):
super(ExponentialEuler, self).__init__(diff_eq)
self._update_code = self.get_integral_step(diff_eq)
self._compile()

@staticmethod
def get_integral_step(diff_eq, *args):
dt = backend.get_dt()
f_expressions = diff_eq.get_f_expressions(substitute_vars=diff_eq.var_name)

# code lines
code_lines = [str(expr) for expr in f_expressions[:-1]]

# get the linear system using sympy
f_res = f_expressions[-1]
df_expr = ast_analysis.str2sympy(f_res.code).expr.expand()
s_df = sympy.Symbol(f"{f_res.var_name}")
code_lines.append(f'{s_df.name} = {ast_analysis.sympy2str(df_expr)}')
var = sympy.Symbol(diff_eq.var_name, real=True)

# get df part
s_linear = sympy.Symbol(f'_{diff_eq.var_name}_linear')
s_linear_exp = sympy.Symbol(f'_{diff_eq.var_name}_linear_exp')
s_df_part = sympy.Symbol(f'_{diff_eq.var_name}_df_part')
if df_expr.has(var):
# linear
linear = sympy.collect(df_expr, var, evaluate=False)[var]
code_lines.append(f'{s_linear.name} = {ast_analysis.sympy2str(linear)}')
# linear exponential
linear_exp = sympy.exp(linear * dt)
code_lines.append(f'{s_linear_exp.name} = {ast_analysis.sympy2str(linear_exp)}')
# df part
df_part = (s_linear_exp - 1) / s_linear * s_df
code_lines.append(f'{s_df_part.name} = {ast_analysis.sympy2str(df_part)}')

else:
# linear exponential
code_lines.append(f'{s_linear_exp.name} = sqrt({dt})')
# df part
code_lines.append(f'{s_df_part.name} = {ast_analysis.sympy2str(dt * s_df)}')

# get dg part
if diff_eq.is_stochastic:
# dW
noise = f'_normal_like_({diff_eq.var_name})'
code_lines.append(f'_{diff_eq.var_name}_dW = {noise}')
# expressions of the stochastic part
g_expressions = diff_eq.get_g_expressions()
code_lines.extend([str(expr) for expr in g_expressions[:-1]])
g_expr = g_expressions[-1].code
# get the dg_part
s_dg_part = sympy.Symbol(f'_{diff_eq.var_name}_dg_part')
code_lines.append(f'_{diff_eq.var_name}_dg_part = {g_expr} * _{diff_eq.var_name}_dW')
else:
s_dg_part = 0

# update expression
update = var + s_df_part + s_dg_part * s_linear_exp

# The actual update step
code_lines.append(f'{diff_eq.var_name} = {ast_analysis.sympy2str(update)}')
return_expr = ', '.join([diff_eq.var_name] + diff_eq.return_intermediates)
code_lines.append(f'_res = {return_expr}')

# final
code = '\n'.join(code_lines)
subs_dict = {arg: f'_{arg}' for arg in diff_eq.func_args + diff_eq.expr_names}
code = tools.word_replace(code, subs_dict)
return code


def exponential_euler(f):
dt = backend.get_dt()
dt_sqrt = dt ** 0.5

def int_f(x, t, *args):
df, linear_part, g = f(x, t, *args)
dW = backend.normal(0., 1., backend.shape(x))
dg = dt_sqrt * g * dW
exp = backend.exp(linear_part * dt)
y1 = x + (exp - 1) / linear_part * df + exp * dg
return y1

return int_f
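

# A hedged usage sketch for the functional form above (names and constants
# are illustrative): `f` must return the triple (df, linear_part, g), where
# `linear_part` is the linearisation coefficient A and `g` the diffusion
# amplitude. For dV = ((-V + Iext) / tau) dt + sigma dW:
#
#   def f(V, t, Iext, tau=10., sigma=0.1):
#       df = (-V + Iext) / tau
#       return df, -1. / tau, sigma
#
#   step = exponential_euler(f)
#   V = step(0., 0., 1.)   # one exponential-Euler step of size `dt`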

+ 442
- 0
brainpy/integrators/sde/srk_scalar.py View File

@@ -0,0 +1,442 @@
# -*- coding: utf-8 -*-

from brainpy import backend
from brainpy.integrators import constants
from . import common

__all__ = [
'srk1w1_scalar',
'srk2w1_scalar',
'KlPl_scalar',
]


# -------
# Helpers
# -------


def _noise_terms(code_lines, variables, vdt, triple_integral=True):
num_vars = len(variables)
if num_vars > 1:
code_lines.append(f' all_I1 = backend.normal(0.0, dt_sqrt, ({num_vars},)+backend.shape({variables[0]}))')
code_lines.append(f' all_I0 = backend.normal(0.0, dt_sqrt, ({num_vars},)+backend.shape({variables[0]}))')
code_lines.append(f' all_I10 = 0.5 * {vdt} * (all_I1 + all_I0 / 3.0 ** 0.5)')
code_lines.append(f' all_I11 = 0.5 * (all_I1 ** 2 - {vdt})')
if triple_integral:
code_lines.append(f' all_I111 = (all_I1 ** 3 - 3 * {vdt} * all_I1) / 6')
code_lines.append(f' ')
for i, var in enumerate(variables):
code_lines.append(f' {var}_I1 = all_I1[{i}]')
code_lines.append(f' {var}_I0 = all_I0[{i}]')
code_lines.append(f' {var}_I10 = all_I10[{i}]')
code_lines.append(f' {var}_I11 = all_I11[{i}]')
if triple_integral:
code_lines.append(f' {var}_I111 = all_I111[{i}]')
code_lines.append(f' ')
else:
var = variables[0]
code_lines.append(f' {var}_I1 = backend.normal(0.0, dt_sqrt, backend.shape({var}))')
code_lines.append(f' {var}_I0 = backend.normal(0.0, dt_sqrt, backend.shape({var}))')
code_lines.append(f' {var}_I10 = 0.5 * {vdt} * ({var}_I1 + {var}_I0 / 3.0 ** 0.5)')
code_lines.append(f' {var}_I11 = 0.5 * ({var}_I1 ** 2 - {vdt})')
if triple_integral:
code_lines.append(f' {var}_I111 = ({var}_I1 ** 3 - 3 * {vdt} * {var}_I1) / 6')
code_lines.append(' ')
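

# Note: the generated quantities are the iterated stochastic integrals that
# the SRK stages consume (Kloeden & Platen notation):
#   I1   ~ N(0, dt)                       (Wiener increment)
#   I10  = 0.5 * dt * (I1 + I0 / sqrt(3)) (represents int_0^dt W(s) ds, with
#                                          I0 an independent N(0, dt) draw)
#   I11  = 0.5 * (I1**2 - dt)             (Ito double integral)
#   I111 = (I1**3 - 3*dt*I1) / 6          (Ito triple integral)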


def _state1(code_lines, variables, parameters):
f_names = [f'{var}_f_H0s1' for var in variables]
g_names = [f'{var}_g_H1s1' for var in variables]
code_lines.append(f' {", ".join(f_names)} = f({", ".join(variables + parameters)})')
code_lines.append(f' {", ".join(g_names)} = g({", ".join(variables + parameters)})')
code_lines.append(' ')


# ---------
# Wrappers
# ---------


def _srk1w1_wrapper(f, g, dt, show_code, sde_type, var_type, wiener_type):
vdt, variables, parameters, arguments, func_name = common.basic_info(f=f, g=g)

# 1. code scope
code_scope = {'f': f, 'g': g, vdt: dt, f'{vdt}_sqrt': dt ** 0.5, 'backend': backend}

# 2. code lines
code_lines = [f'def {func_name}({", ".join(arguments)}):']

# 2.1 noise
_noise_terms(code_lines, variables, vdt, triple_integral=True)

# 2.2 stage 1
_state1(code_lines, variables, parameters)

# 2.3 stage 2
all_H0s2, all_H1s2 = [], []
for var in variables:
code_lines.append(f' {var}_H0s2 = {var} + {vdt} * 0.75 * {var}_f_H0s1 + '
f'1.5 * {var}_g_H1s1 * {var}_I10 / {vdt}')
all_H0s2.append(f'{var}_H0s2')
code_lines.append(f' {var}_H1s2 = {var} + {vdt} * 0.25 * {var}_f_H0s1 + '
f'dt_sqrt * 0.5 * {var}_g_H1s1')
all_H1s2.append(f'{var}_H1s2')
all_H0s2.append(f't + 0.75 * {vdt}') # t
all_H1s2.append(f't + 0.25 * {vdt}') # t
f_names = [f'{var}_f_H0s2' for var in variables]
code_lines.append(f' {", ".join(f_names)} = f({", ".join(all_H0s2 + parameters[1:])})')
g_names = [f'{var}_g_H1s2' for var in variables]
code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s2 + parameters[1:])})')
code_lines.append(' ')

# 2.4 state 3
all_H1s3 = []
for var in variables:
code_lines.append(f' {var}_H1s3 = {var} + {vdt} * {var}_f_H0s1 - dt_sqrt * {var}_g_H1s1')
all_H1s3.append(f'{var}_H1s3')
all_H1s3.append(f't + {vdt}') # t
g_names = [f'{var}_g_H1s3' for var in variables]
code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s3 + parameters[1:])})')
code_lines.append(' ')

# 2.5 state 4
all_H1s4 = []
for var in variables:
code_lines.append(f' {var}_H1s4 = {var} + 0.25 * {vdt} * {var}_f_H0s1 + dt_sqrt * '
f'(-5 * {var}_g_H1s1 + 3 * {var}_g_H1s2 + 0.5 * {var}_g_H1s3)')
all_H1s4.append(f'{var}_H1s4')
all_H1s4.append(f't + 0.25 * {vdt}') # t
g_names = [f'{var}_g_H1s4' for var in variables]
code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s4 + parameters[1:])})')
code_lines.append(' ')

# 2.6 final stage
for var in variables:
code_lines.append(f' {var}_f1 = {var}_f_H0s1/3 + {var}_f_H0s2 * 2/3')
code_lines.append(f' {var}_g1 = -{var}_I1 - {var}_I11/dt_sqrt + 2 * {var}_I10/{vdt} - 2 * {var}_I111/{vdt}')
code_lines.append(f' {var}_g2 = {var}_I1 * 4/3 + {var}_I11 / dt_sqrt * 4/3 - '
f'{var}_I10 / {vdt} * 4/3 + {var}_I111 / {vdt} * 5/3')
code_lines.append(f' {var}_g3 = {var}_I1 * 2/3 - {var}_I11/dt_sqrt/3 - '
f'{var}_I10 / {vdt} * 2/3 - {var}_I111 / {vdt} * 2/3')
code_lines.append(f' {var}_g4 = {var}_I111 / {vdt}')
code_lines.append(f' {var}_new = {var} + {vdt} * {var}_f1 + {var}_g1 * {var}_g_H1s1 + '
f'{var}_g2 * {var}_g_H1s2 + {var}_g3 * {var}_g_H1s3 + {var}_g4 * {var}_g_H1s4')
code_lines.append(' ')

# return and compile
return common.return_compile_and_assign_attrs(
code_lines=code_lines, code_scope=code_scope, show_code=show_code,
variables=variables, parameters=parameters, func_name=func_name,
sde_type=sde_type, var_type=var_type, wiener_type=wiener_type, dt=dt)


def _srk2w1_wrapper(f, g, dt, show_code, sde_type, var_type, wiener_type):
vdt, variables, parameters, arguments, func_name = common.basic_info(f=f, g=g)

# 1. code scope
code_scope = {'f': f, 'g': g, vdt: dt, f'{vdt}_sqrt': dt ** 0.5, 'backend': backend}

# 2. code lines
code_lines = [f'def {func_name}({", ".join(arguments)}):']

# 2.1 noise
_noise_terms(code_lines, variables, vdt, triple_integral=True)

# 2.2 stage 1
_state1(code_lines, variables, parameters)

# 2.3 stage 2
# ----
# H0s2 = x + dt * f_H0s1
# H1s2 = x + dt * 0.25 * f_H0s1 - dt_sqrt * 0.5 * g_H1s1
# f_H0s2 = f(H0s2, t + dt, *args)
# g_H1s2 = g(H1s2, t + 0.25 * dt, *args)
all_H0s2, all_H1s2 = [], []
for var in variables:
code_lines.append(f' {var}_H0s2 = {var} + {vdt} * {var}_f_H0s1')
all_H0s2.append(f'{var}_H0s2')
code_lines.append(f' {var}_H1s2 = {var} + {vdt} * 0.25 * {var}_f_H0s1 - '
f'dt_sqrt * 0.5 * {var}_g_H1s1')
all_H1s2.append(f'{var}_H1s2')
all_H0s2.append(f't + {vdt}') # t
all_H1s2.append(f't + 0.25 * {vdt}') # t
f_names = [f'{var}_f_H0s2' for var in variables]
code_lines.append(f' {", ".join(f_names)} = f({", ".join(all_H0s2 + parameters[1:])})')
g_names = [f'{var}_g_H1s2' for var in variables]
code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s2 + parameters[1:])})')
code_lines.append(' ')

# 2.4 state 3
# ---
# H0s3 = x + dt * (0.25 * f_H0s1 + 0.25 * f_H0s2) + (g_H1s1 + 0.5 * g_H1s2) * I10 / dt
# H1s3 = x + dt * f_H0s1 + dt_sqrt * g_H1s1
# f_H0s3 = g(H0s3, t + 0.5 * dt, *args)
# g_H1s3 = g(H1s3, t + dt, *args)
all_H0s3, all_H1s3 = [], []
for var in variables:
code_lines.append(f' {var}_H0s3 = {var} + {vdt} * (0.25 * {var}_f_H0s1 + 0.25 * {var}_f_H0s2) + '
f'({var}_g_H1s1 + 0.5 * {var}_g_H1s2) * {var}_I10 / {vdt}')
all_H0s3.append(f'{var}_H0s3')
code_lines.append(f' {var}_H1s3 = {var} + {vdt} * {var}_f_H0s1 + dt_sqrt * {var}_g_H1s1')
all_H1s3.append(f'{var}_H1s3')
all_H0s3.append(f't + 0.5 * {vdt}') # t
all_H1s3.append(f't + {vdt}') # t
f_names = [f'{var}_f_H0s3' for var in variables]
g_names = [f'{var}_g_H1s3' for var in variables]
code_lines.append(f' {", ".join(f_names)} = f({", ".join(all_H0s3 + parameters[1:])})')
code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s3 + parameters[1:])})')
code_lines.append(' ')

# 2.5 state 4
# ----
# H1s4 = x + dt * 0.25 * f_H0s3 + dt_sqrt * (2 * g_H1s1 - g_H1s2 + 0.5 * g_H1s3)
# g_H1s4 = g(H1s4, t + 0.25 * dt, *args)
all_H1s4 = []
for var in variables:
code_lines.append(f' {var}_H1s4 = {var} + 0.25 * {vdt} * {var}_f_H0s1 + dt_sqrt * '
f'(2 * {var}_g_H1s1 - {var}_g_H1s2 + 0.5 * {var}_g_H1s3)')
all_H1s4.append(f'{var}_H1s4')
all_H1s4.append(f't + 0.25 * {vdt}') # t
g_names = [f'{var}_g_H1s4' for var in variables]
code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s4 + parameters[1:])})')
code_lines.append(' ')

# 2.6 final stage
# ----
# f1 = f_H0s1 / 6 + f_H0s2 / 6 + f_H0s3 * 2 / 3
# g1 = - I1 + I11 / dt_sqrt + 2 * I10 / dt - 2 * I111 / dt
# g2 = I1 * 4 / 3 - I11 / dt_sqrt * 4 / 3 - I10 / dt * 4 / 3 + I111 / dt * 5 / 3
# g3 = I1 * 2 / 3 + I11 / dt_sqrt / 3 - I10 / dt * 2 / 3 - I111 / dt * 2 / 3
# g4 = I111 / dt
# y1 = x + dt * f1 + g1 * g_H1s1 + g2 * g_H1s2 + g3 * g_H1s3 + g4 * g_H1s4
for var in variables:
code_lines.append(f' {var}_f1 = {var}_f_H0s1/6 + {var}_f_H0s2/6 + {var}_f_H0s3*2/3')
code_lines.append(f' {var}_g1 = -{var}_I1 + {var}_I11/dt_sqrt + 2 * {var}_I10/{vdt} - 2 * {var}_I111/{vdt}')
code_lines.append(f' {var}_g2 = {var}_I1 * 4/3 - {var}_I11 / dt_sqrt * 4/3 - '
f'{var}_I10 / {vdt} * 4/3 + {var}_I111 / {vdt} * 5/3')
code_lines.append(f' {var}_g3 = {var}_I1 * 2/3 + {var}_I11/dt_sqrt/3 - '
f'{var}_I10 / {vdt} * 2/3 - {var}_I111 / {vdt} * 2/3')
code_lines.append(f' {var}_g4 = {var}_I111 / {vdt}')
code_lines.append(f' {var}_new = {var} + {vdt} * {var}_f1 + {var}_g1 * {var}_g_H1s1 + '
f'{var}_g2 * {var}_g_H1s2 + {var}_g3 * {var}_g_H1s3 + {var}_g4 * {var}_g_H1s4')
code_lines.append(' ')

# return and compile
return common.return_compile_and_assign_attrs(
code_lines=code_lines, code_scope=code_scope, show_code=show_code,
variables=variables, parameters=parameters, func_name=func_name,
sde_type=sde_type, var_type=var_type, wiener_type=wiener_type, dt=dt)


def _KlPl_wrapper(f, g, dt, show_code, sde_type, var_type, wiener_type):
vdt, variables, parameters, arguments, func_name = common.basic_info(f=f, g=g)

# 1. code scope
code_scope = {'f': f, 'g': g, vdt: dt, f'{vdt}_sqrt': dt ** 0.5, 'backend': backend}

# 2. code lines
code_lines = [f'def {func_name}({", ".join(arguments)}):']

# 2.1 noise
_noise_terms(code_lines, variables, vdt, triple_integral=False)

# 2.2 stage 1
_state1(code_lines, variables, parameters)

# 2.3 stage 2
# ----
# H1s2 = x + dt * f_H0s1 + dt_sqrt * g_H1s1
# g_H1s2 = g(H1s2, t0, *args)
all_H1s2 = []
for var in variables:
code_lines.append(f' {var}_H1s2 = {var} + {vdt} * {var}_f_H0s1 + dt_sqrt * {var}_g_H1s1')
all_H1s2.append(f'{var}_H1s2')
g_names = [f'{var}_g_H1s2' for var in variables]
code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s2 + parameters)})')
code_lines.append(' ')

# 2.4 final stage
# ----
# g1 = -I1 + I11 / dt_sqrt + I10 / dt
# g2 = I11 / dt_sqrt
# y1 = x + dt * f_H0s1 + g1 * g_H1s1 + g2 * g_H1s2
for var in variables:
code_lines.append(f' {var}_g1 = -{var}_I1 + {var}_I11/dt_sqrt + {var}_I10/{vdt}')
code_lines.append(f' {var}_g2 = {var}_I11 / dt_sqrt')
code_lines.append(f' {var}_new = {var} + {vdt} * {var}_f_H0s1 + '
f'{var}_g1 * {var}_g_H1s1 + {var}_g2 * {var}_g_H1s2')
code_lines.append(' ')

# return and compile
return common.return_compile_and_assign_attrs(
code_lines=code_lines, code_scope=code_scope, show_code=show_code,
variables=variables, parameters=parameters, func_name=func_name,
sde_type=sde_type, var_type=var_type, wiener_type=wiener_type, dt=dt)


def _wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code):
"""The base function to format a SRK method.

Parameters
----------
f : callable
The drift function of the SDE.
g : callable
The diffusion function of the SDE.
dt : float
The numerical precision.
sde_type : str
"constants.ITO_SDE" : Ito's Stochastic Calculus.
"constants.STRA_SDE" : Stratonovich's Stochastic Calculus.
wiener_type : str
The type of Wiener process (scalar or vector).
var_type : str
"scalar" : with the shape of ().
"population" : with the shape of (N,) or (N1, N2) or (N1, N2, ...).
"system" : with the shape of (d,), (d, N), or (d, N1, N2).
show_code : bool
Whether to show the formatted code.

Returns
-------
numerical_func : callable
The numerical function.
"""

var_type = constants.POPU_VAR if var_type is None else var_type
assert var_type in constants.SUPPORTED_VAR_TYPE, f'Currently, BrainPy only supports variable types: ' \
f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.'

sde_type = constants.ITO_SDE if sde_type is None else sde_type
assert sde_type == constants.ITO_SDE, 'SRK method for SDEs with scalar noise only supports Ito SDE type.'

wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type
assert wiener_type == constants.SCALAR_WIENER, 'SRK method for SDEs with scalar noise only supports ' \
'scalar Wiener Process.'

show_code = False if show_code is None else show_code
dt = backend.get_dt() if dt is None else dt

if f is not None and g is not None:
return wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
var_type=var_type, wiener_type=wiener_type)

elif f is not None:
return lambda g: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
var_type=var_type, wiener_type=wiener_type)

elif g is not None:
return lambda f: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
var_type=var_type, wiener_type=wiener_type)

else:
raise ValueError('Must provide "f" or "g".')


# -------------------
# Numerical functions
# -------------------


def srk1w1_scalar(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, show_code=None):
"""Order 2.0 weak SRK methods for SDEs with scalar Wiener process.

This method has strong orders :math:`(p_d, p_s) = (2.0, 1.5)`.

The Butcher table is:

.. math::

\\begin{array}{l|llll|llll|llll}
0 &&&&& &&&& &&&& \\\\
3/4 &3/4&&&& 3/2&&& &&&& \\\\
0 &0&0&0&& 0&0&0&& &&&&\\\\
\\hline
0 \\\\
1/4 & 1/4&&& & 1/2&&&\\\\
1 & 1&0&&& -1&0&\\\\
1/4& 0&0&1/4&& -5&3&1/2\\\\
\\hline
& 1/3& 2/3& 0 & 0 & -1 & 4/3 & 2/3&0 & -1 &4/3 &-1/3 &0 \\\\
\\hline
& &&&& 2 &-4/3 & -2/3 & 0 & -2 & 5/3 & -2/3 & 1
\\end{array}


References
----------

.. [1] Rößler, Andreas. "Strong and weak approximation methods for stochastic differential
equations—some recent developments." Recent developments in applied probability and
statistics. Physica-Verlag HD, 2010. 127-153.
.. [2] Rößler, Andreas. "Runge–Kutta methods for the strong approximation of solutions of
stochastic differential equations." SIAM Journal on Numerical Analysis 48.3
(2010): 922-952.

"""
return _wrap(_srk1w1_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type,
wiener_type=wiener_type, show_code=show_code)


def srk2w1_scalar(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, show_code=None):
"""Order 1.5 Strong SRK Methods for SDEs witdt Scalar Noise.

This method has have strong orders :math:`(p_d, p_s) = (3.0,1.5)`.

The Butcher table is:

.. math::

\\begin{array}{c|cccc|cccc|ccc|}
0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\\\
1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\\\
1 / 2 & 1 / 4 & 1 / 4 & 0 & 0 & 1 & 1 / 2 & 0 & 0 & & & & \\\\
0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\\\
\\hline 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\\\
1 / 4 & 1 / 4 & 0 & 0 & 0 & -1 / 2 & 0 & 0 & 0 & & & & \\\\
1 & 1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & & & & \\\\
1 / 4 & 0 & 0 & 1 / 4 & 0 & 2 & -1 & 1 / 2 & 0 & & & & \\\\
\\hline & 1 / 6 & 1 / 6 & 2 / 3 & 0 & -1 & 4 / 3 & 2 / 3 & 0 & -1 & -4 / 3 & 1 / 3 & 0 \\\\
\\hline & & & & &2 & -4 / 3 & -2 / 3 & 0 & -2 & 5 / 3 & -2 / 3 & 1
\\end{array}


References
----------

.. [1] Rößler, Andreas. "Strong and weak approximation methods for stochastic differential
equations—some recent developments." Recent developments in applied probability and
statistics. Physica-Verlag HD, 2010. 127-153.
.. [2] Rößler, Andreas. "Runge–Kutta methods for the strong approximation of solutions of
stochastic differential equations." SIAM Journal on Numerical Analysis 48.3
(2010): 922-952.
"""
return _wrap(_srk2w1_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type,
wiener_type=wiener_type, show_code=show_code)


def KlPl_scalar(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, show_code=None):
"""Order 1.0 Strong SRK Methods for SDEs with Scalar Noise.

This method has strong order :math:`p_s = 1.0`.

The Butcher table is:

.. math::

\\begin{array}{c|cc|cc|cc|c}
0 & 0 & 0 & 0 & 0 & & \\\\
0 & 0 & 0 & 0 & 0 & & \\\\
\\hline 0 & 0 & 0 & 0 & 0 & & \\\\
0 & 1 & 0 & 1 & 0 & & \\\\
\\hline 0 & 1 & 0 & 1 & 0 & -1 & 1 \\\\
\\hline & & & 1 & 0 & 0 & 0
\\end{array}

References
----------

.. [1] P. E. Kloeden, E. Platen, Numerical Solution of Stochastic Differential
Equations, 2nd Edition, Springer, Berlin Heidelberg New York, 1995.
"""
return _wrap(_KlPl_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type,
wiener_type=wiener_type, show_code=show_code)
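

# A hedged usage sketch for the scalar-noise SRK family (names and constants
# are illustrative):
#
#   def df(x, t, a):
#       return a - x
#
#   def dg(x, t, a):
#       return 0.1 * x   # multiplicative scalar noise
#
#   step = srk1w1_scalar(f=df, g=dg, dt=0.01)
#   x = step(1.0, 0.0, 1.5)   # one strong-order-1.5 step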

+ 442
- 0
brainpy/integrators/sde/srk_strong.py View File

@@ -0,0 +1,442 @@
# -*- coding: utf-8 -*-

import math

from brainpy import backend
from brainpy.integrators import constants
from . import common

__all__ = [
'srk1_strong',
]


def _vector_wiener_terms(code_lines, sde_type, vdt, shape_D, shape_m):
if sde_type == constants.ITO_SDE:
I2 = f'0.5*(_term3 - {vdt} * backend.eye({shape_m})) + _a*0.5*{vdt}/math.pi'
elif sde_type == constants.STRA_SDE:
I2 = f'0.5*_term3 + _a*0.5*{vdt}/math.pi'
else:
raise ValueError(f'Unknown SDE type: {sde_type}. We only support {constants.SUPPORTED_SDE_TYPE}.')

if shape_D:
shape_D = shape_D + '+'

noise_string = f'''
# Noise Terms #
# ----------- #
# single Ito integrals
_I1 = backend.normal(0., {vdt}_sqrt, {shape_D}({shape_m},))
# double Ito integrals
_h = (2.0 / {vdt}) ** 0.5
_a = backend.zeros(shape={shape_D}({shape_m}, {shape_m}))
for _k in range(1, num_iter + 1):
_x = backend.normal(loc=0., scale=1., size={shape_D}({shape_m}, 1))
_y = backend.normal(loc=0., scale=1., size={shape_D}(1, {shape_m})) + _h * _I1
_term1 = backend.matmul(_x, _y)
_term2 = backend.matmul(backend.reshape(_y, {shape_D}({shape_m}, 1)),
backend.reshape(_x, {shape_D}(1, {shape_m})))
_a += (_term1 - _term2) / _k
_I1_rs = backend.reshape(_I1, {shape_D}({shape_m}, 1))
_term3 = backend.matmul(_I1_rs, backend.reshape(_I1, {shape_D}(1, {shape_m})))
_I2 = {I2}
'''
noise_lines = noise_string.split('\n')
code_lines.extend(noise_lines)
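

# Note: the `num_iter` loop in the emitted code appears to be the truncated
# Kloeden-Platen series for the Levy-area (off-diagonal) part of the double
# integrals I_(j1, j2); larger `num_iter` improves that approximation, while
# the diagonal entries use the closed form 0.5 * (I1**2 - dt) in the Ito case.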


# ----------
# Wrapper
# ----------


def _srk2_pop_var_vector_wiener(sde_type, code_lines, variables, parameters, vdt):
# shape information
# -----
all_f = [f'f_{var}' for var in variables]
all_g = [f'g_{var}' for var in variables]
noise_string = f'''
{", ".join(all_f)} = f({", ".join(variables + parameters)}) # shape = (..)
{", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (.., m)
noise_shape = backend.shape(g_{variables[0]})
_D = noise_shape[:-1]
_m = noise_shape[-1]
'''
code_lines.extend(noise_string.split("\n"))

# noise terms
_vector_wiener_terms(code_lines, sde_type, vdt, shape_D='_D', shape_m='_m')

# numerical integration
# step 1
# ---
# g_x1_rs = backend.reshape(g_x1, _D + (1, _m))
# g_x2_rs = backend.reshape(g_x2, _D + (1, _m))
for var in variables:
code_lines.append(f" g_{var}_rs = backend.reshape(g_{var}, _D+(1, _m))")
# step 2
# ---
# g_H1_x1 = backend.reshape(backend.matmul(g_x1_rs, _I2) / dt_sqrt, _D + (_m,))
# g_H1_x2 = backend.reshape(backend.matmul(g_x2_rs, _I2) / dt_sqrt, _D + (_m,))
for var in variables:
code_lines.append(f' g_H1_{var} = backend.reshape(backend.matmul(g_{var}_rs, _I2) / {vdt}_sqrt, _D + (_m,))')
# step 3
# ---
# x1_rs = backend.reshape(x1, _D + (1,))
# x2_rs = backend.reshape(x2, _D + (1,))
for var in variables:
code_lines.append(f' {var}_rs = backend.reshape({var}, _D + (1,))')
# step 4
# ---
# H2_x1 = x1_rs + g_H1_x1
# H3_x1 = x1_rs - g_H1_x1
for var in variables:
code_lines.append(f' H2_{var} = {var}_rs + g_H1_{var}')
code_lines.append(f' H3_{var} = {var}_rs - g_H1_{var}')
code_lines.append(' ')
# step 5
# ---
# _g_x1 = backend.matmul(g_x1_rs, _I1_rs)
for var in variables:
code_lines.append(f' _g_{var} = backend.matmul(g_{var}_rs, _I1_rs)')
# step 6
# ----
# x1_new = x1 + f_x1 + _g_x1[..., 0, 0]
for var in variables:
code_lines.append(f' {var}_new = {var} + f_{var} + _g_{var}[..., 0, 0]')
# for _k in range(_m):
code_lines.append('for _k in range(_m):')
# g_x1_H2, g_x2_H2 = g(H2_x1[..., _k], H2_x2[..., _k], t, *args)
all_H2 = [f'H2_{var}[..., _k]' for var in variables]
all_g_H2 = [f'g_{var}_H2' for var in variables]
code_lines.append(f' {", ".join(all_g_H2)} = g({", ".join(all_H2 + parameters)})')
# g_x1_H3, g_x2_H3 = g(H3_x1[..., _k], H3_x2[..., _k], t, *args)
all_H3 = [f'H3_{var}[..., _k]' for var in variables]
all_g_H3 = [f'g_{var}_H3' for var in variables]
code_lines.append(f' {", ".join(all_g_H3)} = g({", ".join(all_H3 + parameters)})')
# x1_new += 0.5 * dt_sqrt * (g_x1_H2[..., _k] - g_x1_H3[..., _k])
# x2_new += 0.5 * dt_sqrt * (g_x2_H2[..., _k] - g_x2_H3[..., _k])
for var in variables:
code_lines.append(f' {var}_new += 0.5 * {vdt}_sqrt * (g_{var}_H2[..., _k] - g_{var}_H3[..., _k])')


def _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt):
if sde_type == constants.ITO_SDE:
I2 = f'0.5 * (_I1 * _I1 - {vdt})'
elif sde_type == constants.STRA_SDE:
I2 = f'0.5 * _I1 * _I1'
else:
raise ValueError(f'Unknown SDE type: {sde_type}. We only support {constants.SUPPORTED_SDE_TYPE}.')

# shape info
# -----
all_f = [f'f_{var}' for var in variables]
all_g = [f'g_{var}' for var in variables]

code_string = f'''
{", ".join(all_f)} = f({", ".join(variables + parameters)}) # shape = (..)
{", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (..)

# single Ito integrals
_I1 = backend.normal(0., {vdt}_sqrt, backend.shape({variables[0]})) # shape = (..)
# double Ito integrals
_I2 = {I2} # shape = (..)
'''
code_splits = code_string.split('\n')
code_lines.extend(code_splits)

# numerical integration
# -----
# H1
for var in variables:
code_lines.append(f' g_H1_{var} = g_{var} * _I2 / {vdt}_sqrt # shape (.., )')
# H2
all_H2 = [f'H2_{var}' for var in variables]
for var in variables:
code_lines.append(f' H2_{var} = {var} + g_H1_{var} # shape (.., )')
all_g_H2 = [f'g_{var}_H2' for var in variables]
code_lines.append(f' {", ".join(all_g_H2)} = g({", ".join(all_H2 + parameters)})')
code_lines.append(f' ')
# H3
all_H3 = [f'H3_{var}' for var in variables]
for var in variables:
code_lines.append(f' H3_{var} = {var} - g_H1_{var} # shape (.., )')
all_g_H3 = [f'g_{var}_H3' for var in variables]
code_lines.append(f' {", ".join(all_g_H3)} = g({", ".join(all_H3 + parameters)})')
code_lines.append(f' ')
# final results
for var in variables:
code_lines.append(f' {var}_new = {var} + f_{var} + g_{var} * _I1 '
f'+ 0.5 * {vdt}_sqrt * (g_{var}_H2 - g_{var}_H3)')


def _srk1_scalar_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt):
# shape information
all_f = [f'f_{var}' for var in variables]
all_g = [f'g_{var}' for var in variables]
code1 = f'''
# shape info #
# ---------- #

{", ".join(all_f)} = f({", ".join(variables + parameters)}) # shape = ()
{", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (m)
noise_shape = backend.shape(g_{variables[0]})
_m = noise_shape[0]
'''
code_lines.extend(code1.split('\n'))

# noise term
_vector_wiener_terms(code_lines, sde_type, vdt, shape_D='', shape_m='_m')

# numerical integration

# p1
# ---
# g_x1_rs = backend.reshape(g_x1, (1, _m))
# g_x2_rs = backend.reshape(g_x2, (1, _m))
for var in variables:
code_lines.append(f' g_{var}_rs = backend.reshape(g_{var}, (1, _m))')

# p2
# ---
# g_H1_x1 = backend.matmul(g_x1_rs, _I2) / dt_sqrt # shape (1, m)
# g_H1_x2 = backend.matmul(g_x2_rs, _I2) / dt_sqrt # shape (1, m)
for var in variables:
code_lines.append(f' g_H1_{var} = backend.matmul(g_{var}_rs, _I2) / {vdt}_sqrt # shape (1, m)')

# p3
# ---
# H2_x1 = x1 + g_H1_x1[0] # shape (m)
# H3_x1 = x1 - g_H1_x1[0] # shape (m)
for var in variables:
code_lines.append(f' H2_{var} = {var} + g_H1_{var}[0] # shape (m)')
code_lines.append(' ')

# p4
# ---
# g1_x1 = backend.matmul(g_x1_rs, _I1_rs) # shape (1, 1)
# x1_new = x1 + f_x1 + g1_x1[0, 0] # shape ()
for var in variables:
code_lines.append(f' g1_{var} = backend.matmul(g_{var}_rs, _I1_rs) # shape (1, 1)')
code_lines.append(f' {var}_new = {var} + f_{var} + g1_{var}[0, 0] # shape ()')

# p5
# ---
# for _k in range(_m):
# g_x1_H2, g_x2_H2 = g(H2_x1[_k], H2_x2[_k], t, *args)
# g_x1_H3, g_x2_H3 = g(H3_x1[_k], H3_x2[_k], t, *args)
# x1_new += 0.5 * dt_sqrt * (g_x1_H2[_k] - g_x1_H3[_k])
# x2_new += 0.5 * dt_sqrt * (g_x2_H2[_k] - g_x2_H3[_k])
code_lines.append(' for _k in range(_m):')
all_h2_k = [f'H2_{var}[_k]' for var in variables]
all_g_h2 = [f'g_{var}_H2' for var in variables]
code_lines.append(f' {", ".join(all_g_h2)} = g({", ".join(all_h2_k + parameters)})')
all_h3_k = [f'H3_{var}[_k]' for var in variables]
all_g_h3 = [f'g_{var}_H3' for var in variables]
code_lines.append(f' {", ".join(all_g_h3)} = g({", ".join(all_h3_k + parameters)})')
for var in variables:
code_lines.append(f' {var}_new += 0.5 * {vdt}_sqrt * (g_{var}_H2[_k] - g_{var}_H3[_k])')


def _srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt):
# shape information
code1 = f'''
# shape info #
# ----------- #
f_x = f({", ".join(variables + parameters)}) # shape = (d, ..)
g_x = g({", ".join(variables + parameters)}) # shape = (d, .., m)
_shape = backend.shape(g_x)
_d = _shape[0]
_m = _shape[-1]
_D = _shape[1:-1]
'''
code_lines.extend(code1.split('\n'))

# noise term
_vector_wiener_terms(code_lines, sde_type, vdt, shape_D='_D', shape_m='_m')

# numerical integration
code2 = f'''
# numerical integration #
# --------------------- #
g_x2 = backend.moveaxis(g_x, 0, -2) # shape = (.., d, m)
g_H1_k = backend.matmul(g_x2, _I2) / dt_sqrt # shape (.., d, m)
g_H1_k = backend.moveaxis(g_H1_k, -2, 0) # shape (d, .., m)
x_rs = backend.reshape(x, (_d,) + _D + (1,))
H2 = x_rs + g_H1_k # shape (d, .., m)
H3 = x_rs - g_H1_k # shape (d, .., m)
g1 = backend.matmul(g_x2, _I1_rs) # shape (.., d, 1)
g1 = backend.moveaxis(g1, -2, 0) # shape (d, .., 1)
y = x + f_x + g1[..., 0] # shape (d, ..)
for _k in range(_m):
y += 0.5 * dt_sqrt * g(H2[..., _k], t, *args)[..., _k]
y -= 0.5 * dt_sqrt * g(H3[..., _k], t, *args)[..., _k]
'''
code_lines.extend(code2.split('\n'))


def _srk1_system_var_with_scalar_wiener(sde_type, code_lines, variables, parameters, vdt):
if sde_type == constants.ITO_SDE:
I2 = f'0.5 * (_I1 * _I1 - {vdt})'
elif sde_type == constants.STRA_SDE:
I2 = f'0.5 * _I1 * _I1'
else:
raise ValueError(f'Unknown SDE type: {sde_type}. We only support {constants.SUPPORTED_SDE_TYPE}.')

code_string = f'''
f_x = f({", ".join(variables + parameters)}) # shape = (d, ..)
g_x = g({", ".join(variables + parameters)}) # shape = (d, ..)
_shape = backend.shape(g_x)
_d = _shape[0]
_D = _shape[1:]

# single Ito integrals
_I1 = backend.normal(0., {vdt}_sqrt, _D) # shape = (..)
# double Ito integrals
_I2 = {I2} # shape = (..)

# numerical integration #
# --------------------- #
g_H1_k = g_x * _I2 / {vdt}_sqrt # shape (d, ..)
H2 = x + g_H1_k # shape (d, ..)
H3 = x - g_H1_k # shape (d, ..)

g1 = g_x * _I1 # shape (d, ..)
x_new = x + f_x + g1 # shape (d, ..)
x_new += 0.5 * {vdt}_sqrt * g(H2, {", ".join(parameters)})
x_new -= 0.5 * {vdt}_sqrt * g(H3, {", ".join(parameters)})
'''
code_splits = code_string.split('\n')
code_lines.extend(code_splits)


def _srk1_wrapper(f, g, dt, sde_type, var_type, wiener_type, show_code, num_iter):
vdt, variables, parameters, arguments, func_name = common.basic_info(f=f, g=g)

# 1. code scope
code_scope = {'f': f, 'g': g, vdt: dt, f'{vdt}_sqrt': dt ** 0.5,
'backend': backend, 'num_iter': num_iter, 'math': math}

# 2. code lines
code_lines = [f'def {func_name}({", ".join(arguments)}):']

if var_type == constants.SYSTEM_VAR:
if len(variables) > 1:
raise ValueError(f'SDE with {constants.SYSTEM_VAR} variable type only '
f'supports one system variable. But we got {variables}.')

if wiener_type == constants.SCALAR_WIENER:
_srk1_system_var_with_scalar_wiener(sde_type, code_lines, variables, parameters, vdt)
elif wiener_type == constants.VECTOR_WIENER:
_srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt)
else:
raise ValueError(f'Unknown Wiener type: {wiener_type}, we only '
f'support {constants.SUPPORTED_WIENER_TYPE}')

elif var_type == constants.SCALAR_VAR:
if wiener_type == constants.SCALAR_WIENER:
_srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt)
elif wiener_type == constants.VECTOR_WIENER:
_srk1_scalar_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt)
else:
raise ValueError(f'Unknown Wiener type: {wiener_type}, we only '
f'support {constants.SUPPORTED_WIENER_TYPE}')

elif var_type == constants.POPU_VAR:
if wiener_type == constants.SCALAR_WIENER:
_srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt)
elif wiener_type == constants.VECTOR_WIENER:
_srk2_pop_var_vector_wiener(sde_type, code_lines, variables, parameters, vdt)
else:
raise ValueError(f'Unknown Wiener type: {wiener_type}, we only '
f'support {constants.SUPPORTED_WIENER_TYPE}')

else:
raise ValueError(f'Unknown var type: {var_type}, we only '
f'support {constants.SUPPORTED_VAR_TYPE}')

# return and compile
return common.return_compile_and_assign_attrs(
code_lines=code_lines, code_scope=code_scope, show_code=show_code,
variables=variables, parameters=parameters, func_name=func_name,
sde_type=sde_type, var_type=var_type, wiener_type=wiener_type, dt=dt)


def _srk2_wrapper():
pass


def _wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code, num_iter):
"""The base function to format a SRK method.

Parameters
----------
f : callable
The drift function of the SDE.
g : callable
The diffusion function of the SDE.
dt : float
The numerical precision.
sde_type : str
"constants.ITO_SDE" : Ito's Stochastic Calculus.
"constants.STRA_SDE" : Stratonovich's Stochastic Calculus.
wiener_type : str
The type of Wiener process (scalar or vector).
var_type : str
"scalar" : with the shape of ().
"population" : with the shape of (N,) or (N1, N2) or (N1, N2, ...).
"system" : with the shape of (d,), (d, N), or (d, N1, N2).
show_code : bool
Whether to show the formatted code.

Returns
-------
numerical_func : callable
The numerical function.
"""

sde_type = constants.ITO_SDE if sde_type is None else sde_type
assert sde_type in constants.SUPPORTED_SDE_TYPE, f'Currently, BrainPy only supports SDE types: ' \
f'{constants.SUPPORTED_SDE_TYPE}. But we got {sde_type}.'

var_type = constants.POPU_VAR if var_type is None else var_type
assert var_type in constants.SUPPORTED_VAR_TYPE, f'Currently, BrainPy only supports variable types: ' \
f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.'

wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type
assert wiener_type in constants.SUPPORTED_WIENER_TYPE, f'Currently, BrainPy only supports Wiener ' \
f'Process types: {constants.SUPPORTED_WIENER_TYPE}. ' \
f'But we got {wiener_type}.'

show_code = False if show_code is None else show_code
dt = backend.get_dt() if dt is None else dt
num_iter = 10 if num_iter is None else num_iter

if f is not None and g is not None:
return wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
var_type=var_type, wiener_type=wiener_type, num_iter=num_iter)

elif f is not None:
return lambda g: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
var_type=var_type, wiener_type=wiener_type, num_iter=num_iter)

elif g is not None:
return lambda f: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
var_type=var_type, wiener_type=wiener_type, num_iter=num_iter)

else:
raise ValueError('Must provide "f" or "g".')


# ------------------
# Numerical methods
# ------------------


def srk1_strong(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, num_iter=None, show_code=None):
return _wrap(_srk1_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type,
wiener_type=wiener_type, show_code=show_code, num_iter=num_iter)


def srk2_strong(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, num_iter=None, show_code=None):
return _wrap(_srk2_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type,
wiener_type=wiener_type, show_code=show_code, num_iter=num_iter)
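For orientation, a minimal usage sketch of the strong SRK schemes above (a scalar Ito SDE; the generated function's argument order is an assumption based on the drift function's signature, not verified API):

def f(x, t, tau):           # drift of a hypothetical Ornstein-Uhlenbeck process
    return -x / tau

def g(x, t, tau):           # diffusion amplitude (constant, for illustration)
    return 0.5

# direct call; passing only `f` (or only `g`) instead returns a decorator,
# as implemented by `_wrap` above
step = srk1_strong(f=f, g=g, dt=0.01)
x_next = step(1.0, 0., 10.)   # one integration step: (x, t, tau)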

brainpy/integration/utils.py → brainpy/integrators/sympy_analysis.py View File

@@ -2,9 +2,18 @@

import ast
import math
from collections import Counter

import numpy as np

from brainpy import errors
from brainpy import tools

try:
    import sympy
except ModuleNotFoundError:
    raise errors.PackageMissingError('Package "sympy" must be installed when '
                                     'users want to use the sympy analysis.')
import sympy.functions.elementary.complexes
import sympy.functions.elementary.exponential
import sympy.functions.elementary.hyperbolic
@@ -15,27 +24,10 @@ from sympy.codegen import cfunctions
from sympy.printing.precedence import precedence
from sympy.printing.str import StrPrinter


__all__ = [
'FUNCTION_MAPPING',
'CONSTANT_MAPPING',
'Parser',
'Printer',
'str2sympy',
'sympy2str',
'get_mapping_scope',
'DiffEquationAnalyser',
'analyse_diff_eq',
]
CONSTANT_NOISE = 'CONSTANT'
FUNCTIONAL_NOISE = 'FUNCTIONAL'

FUNCTION_MAPPING = {
# 'real': sympy.functions.elementary.complexes.re,
# 'imag': sympy.functions.elementary.complexes.im,
# 'conjugate': sympy.functions.elementary.complexes.conjugate,

# functions in inherit python
# ---------------------------
'abs': sympy.functions.elementary.complexes.Abs,
@@ -58,9 +50,6 @@ FUNCTION_MAPPING = {
'expm1': cfunctions.expm1,
'exp2': cfunctions.exp2,

# 'maximum': sympy.functions.elementary.miscellaneous.Max,
# 'minimum': sympy.functions.elementary.miscellaneous.Min,

# functions in math
# ------------------
'asin': sympy.functions.elementary.trigonometric.asin,
@@ -99,65 +88,42 @@ CONSTANT_MAPPING = {
'inf': sympy.S.Infinity,
}

# Get functions in math
_functions_in_math = []
for key in dir(math):
if not key.startswith('__'):
_functions_in_math.append(getattr(math, key))

# Get functions in NumPy
_functions_in_numpy = []
for key in dir(np):
if not key.startswith('__'):
_functions_in_numpy.append(getattr(np, key))
for key in dir(np.random):
if not key.startswith('__'):
_functions_in_numpy.append(getattr(np.random, key))
for key in dir(np.linalg):
if not key.startswith('__'):
_functions_in_numpy.append(getattr(np.linalg, key))


def func_in_numpy_or_math(func):
return func in _functions_in_math or func in _functions_in_numpy


def get_mapping_scope():
if profile.run_on_cpu():
return {
'sign': np.sign, 'cos': np.cos, 'sin': np.sin, 'tan': np.tan,
'sinc': np.sinc, 'arcsin': np.arcsin, 'arccos': np.arccos,
'arctan': np.arctan, 'arctan2': np.arctan2, 'cosh': np.cosh,
            'sinh': np.sinh, 'tanh': np.tanh, 'arcsinh': np.arcsinh,
'arccosh': np.arccosh, 'arctanh': np.arctanh, 'ceil': np.ceil,
'floor': np.floor, 'log': np.log, 'log2': np.log2, 'log1p': np.log1p,
'log10': np.log10, 'exp': np.exp, 'expm1': np.expm1, 'exp2': np.exp2,
'hypot': np.hypot, 'sqrt': np.sqrt, 'pi': np.pi, 'e': np.e, 'inf': np.inf,
'asin': math.asin, 'acos': math.acos, 'atan': math.atan, 'atan2': math.atan2,
'asinh': math.asinh, 'acosh': math.acosh, 'atanh': math.atanh,
# 'Max': np.maximum, 'Min': np.minimum
}
else:
return {
# functions in numpy
# ------------------
'arcsin': math.asin, 'arccos': math.acos,
'arctan': math.atan, 'arctan2': math.atan2, 'arcsinh': math.asinh,
'arccosh': math.acosh, 'arctanh': math.atanh,
'sign': np.sign, 'sinc': np.sinc,
'log2': np.log2, 'log1p': np.log1p,
'expm1': np.expm1, 'exp2': np.exp2,
# 'Max': max, 'Min': min,

# functions in math
# ------------------
'asin': math.asin,
'acos': math.acos,
'atan': math.atan,
'atan2': math.atan2,
'asinh': math.asinh,
'acosh': math.acosh,
'atanh': math.atanh,

# functions in both numpy and math
# --------------------------------
'cos': math.cos,
'sin': math.sin,
'tan': math.tan,
'cosh': math.cosh,
'sinh': math.sinh,
'tanh': math.tanh,
'log': math.log,
'log10': math.log10,
'sqrt': math.sqrt,
'exp': math.exp,
'hypot': math.hypot,
'ceil': math.ceil,
'floor': math.floor,

# constants in both numpy and math
# --------------------------------
'pi': math.pi,
'e': math.e,
'inf': math.inf}
return {
'sign': np.sign, 'cos': np.cos, 'sin': np.sin, 'tan': np.tan,
'sinc': np.sinc, 'arcsin': np.arcsin, 'arccos': np.arccos,
'arctan': np.arctan, 'arctan2': np.arctan2, 'cosh': np.cosh,
        'sinh': np.sinh, 'tanh': np.tanh, 'arcsinh': np.arcsinh,
'arccosh': np.arccosh, 'arctanh': np.arctanh, 'ceil': np.ceil,
'floor': np.floor, 'log': np.log, 'log2': np.log2, 'log1p': np.log1p,
'log10': np.log10, 'exp': np.exp, 'expm1': np.expm1, 'exp2': np.exp2,
'hypot': np.hypot, 'sqrt': np.sqrt, 'pi': np.pi, 'e': np.e, 'inf': np.inf,
'asin': math.asin, 'acos': math.acos, 'atan': math.atan, 'atan2': math.atan2,
'asinh': math.asinh, 'acosh': math.acosh, 'atanh': math.atanh,
}


class Parser(object):
@@ -371,8 +337,9 @@ class Parser(object):

class Printer(StrPrinter):
"""
    Printer that overrides the printing of some basic sympy objects.

    e.g. print "a and b" instead of "And(a, b)".
"""

def _print_And(self, expr):
@@ -425,136 +392,225 @@ def sympy2str(sympy_expr):
return _PRINTER.doprint(sympy_expr)
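# Round-trip sketch (assuming `str2sympy` returns a wrapper exposing `.expr`,
# as it is used by `SingleDiffEq._substitute` below):
#   >>> sympy2str(str2sympy('a * exp(-v / tau)').expr)
#   'a * exp(-v / tau)'        # up to sympy's canonical reordering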


class DiffEquationAnalyser(ast.NodeTransformer):
expression_ops = {
'Add': '+', 'Sub': '-', 'Mult': '*', 'Div': '/',
'Mod': '%', 'Pow': '**', 'BitXor': '^', 'BitAnd': '&',
}
class Expression(object):
def __init__(self, var, code):
self.var_name = var
self.code = code.strip()
self.substituted_code = None

def __init__(self):
self.variables = []
self.expressions = []
self.f_expr = None
self.g_expr = None
self.returns = []
self.return_type = None

# TODO : Multiple assignment like "a = b = 1" or "a, b = f()"
def visit_Assign(self, node):
targets = node.targets
try:
assert len(targets) == 1
except AssertionError:
            raise errors.DiffEquationError('BrainPy currently does not support multiple '
                                           'assignment in differential equations.')
self.variables.append(targets[0].id)
self.expressions.append(tools.ast2code(ast.fix_missing_locations(node.value)))
return node

def visit_AugAssign(self, node):
var = node.target.id
self.variables.append(var)
op = tools.ast2code(ast.fix_missing_locations(node.op))
expr = tools.ast2code(ast.fix_missing_locations(node.value))
self.expressions.append(f"{var} {op} {expr}")
return node

def visit_AnnAssign(self, node):
        raise errors.DiffEquationError('Assignments with type annotations are not supported.')

def visit_Return(self, node):
value = node.value
if isinstance(value, (ast.Tuple, ast.List)): # a tuple/list return
v0 = value.elts[0]
if isinstance(v0, (ast.Tuple, ast.List)): # item 0 is a tuple/list
# f expression
if isinstance(v0.elts[0], ast.Name):
self.f_expr = ('_f_res_', v0.elts[0].id)
else:
self.f_expr = ('_f_res_', tools.ast2code(ast.fix_missing_locations(v0.elts[0])))

if len(v0.elts) == 1:
self.return_type = '(x,),'
elif len(v0.elts) == 2:
self.return_type = '(x,x),'
# g expression
if isinstance(v0.elts[1], ast.Name):
self.g_expr = ('_g_res_', v0.elts[1].id)
else:
self.g_expr = ('_g_res_', tools.ast2code(ast.fix_missing_locations(v0.elts[1])))
else:
raise errors.DiffEquationError(f'The dxdt should have the format of (f, g), not '
f'"({tools.ast2code(ast.fix_missing_locations(v0.elts))})"')

# returns
for i, item in enumerate(value.elts[1:]):
if isinstance(item, ast.Name):
self.returns.append(item.id)
else:
self.returns.append(tools.ast2code(ast.fix_missing_locations(item)))

else: # item 0 is not a tuple/list
# f expression
if isinstance(v0, ast.Name):
self.f_expr = ('_f_res_', v0.id)
else:
self.f_expr = ('_f_res_', tools.ast2code(ast.fix_missing_locations(v0)))

if len(value.elts) == 1:
self.return_type = 'x,'
elif len(value.elts) == 2:
self.return_type = 'x,x'
# g expression
if isinstance(value.elts[1], ast.Name):
self.g_expr = ('_g_res_', value.elts[1].id)
else:
self.g_expr = ("_g_res_", tools.ast2code(ast.fix_missing_locations(value.elts[1])))
else:
raise errors.DiffEquationError('Cannot parse return expression. It should have the '
'format of "(f, [g]), [return values]"')
else:
self.return_type = 'x'
if isinstance(value, ast.Name): # a name return
self.f_expr = ('_f_res_', value.id)
else: # an expression return
self.f_expr = ('_f_res_', tools.ast2code(ast.fix_missing_locations(value)))
return node
@property
def identifiers(self):
return tools.get_identifiers(self.code)

def visit_If(self, node):
raise errors.DiffEquationError('Do not support "if" statement in differential equation.')
def __str__(self):
return f'{self.var_name} = {self.code}'

def visit_IfExp(self, node):
raise errors.DiffEquationError('Do not support "if" expression in differential equation.')
def __repr__(self):
return self.__str__()

def visit_For(self, node):
raise errors.DiffEquationError('Do not support "for" loop in differential equation.')
def __eq__(self, other):
if not isinstance(other, Expression):
return NotImplemented
if self.code != other.code:
return False
if self.var_name != other.var_name:
return False
return True

def visit_While(self, node):
raise errors.DiffEquationError('Do not support "while" loop in differential equation.')
def __ne__(self, other):
return not self.__eq__(other)

def visit_Try(self, node):
raise errors.DiffEquationError('Do not support "try" handler in differential equation.')
def get_code(self, subs=True):
if subs:
if self.substituted_code is None:
return self.code
else:
return self.substituted_code
else:
return self.code

def visit_With(self, node):
raise errors.DiffEquationError('Do not support "with" block in differential equation.')

def visit_Raise(self, node):
raise errors.DiffEquationError('Do not support "raise" statement.')
class SingleDiffEq(object):
"""Single Differential Equation.

def visit_Delete(self, node):
raise errors.DiffEquationError('Do not support "del" operation.')
A differential equation is defined as the standard form:

dx/dt = f(x) + g(x) dW

def analyse_diff_eq(eq_code):
assert eq_code.strip() != ''
tree = ast.parse(eq_code)
analyser = DiffEquationAnalyser()
analyser.visit(tree)
Parameters
----------
var_name : str
        The variable name.
variables : list
The code variables.
expressions : list
The code expressions for each line.
derivative_expr : str
The final derivative expression.
    scope : dict
        The code scope.
    func_name : str
        The function name.
"""

res = tools.DictPlus(variables=analyser.variables,
expressions=analyser.expressions,
returns=analyser.returns,
return_type=analyser.return_type,
f_expr=analyser.f_expr,
g_expr=analyser.g_expr)
return res
def __init__(self, var_name, variables, expressions, derivative_expr, scope,
func_name):
self.func_name = func_name
# function scope
self.func_scope = scope

# differential variable name and time name
self.var_name = var_name
self.t_name = 't'

# analyse function code
self.expressions = [Expression(v, expr) for v, expr in zip(variables, expressions)]
self.f_expr = Expression('_f_res_', derivative_expr)
for k, num in Counter(variables).items():
if num > 1:
                raise errors.AnalyzerError(
                    f'Found "{k}" assigned {num} times. Please give each expression '
                    f'in the differential function a unique name.')

def _substitute(self, final_exp, expressions, substitute_vars=None):
"""Substitute expressions to get the final single expression

Parameters
----------
final_exp : Expression
The final expression.
        expressions : list, tuple
            The list/tuple of expressions.
        substitute_vars : str, list, tuple, optional
            Which variables to substitute: 'all', the differential variable
            name, or a list/tuple of variable names.
        """
if substitute_vars is None:
return
if final_exp is None:
return
assert substitute_vars == 'all' or \
substitute_vars == self.var_name or \
isinstance(substitute_vars, (tuple, list))

        # Goal: substitute dependent variables into the expression
        # Hint: this step doesn't require the left-hand variables to be unique
dependencies = {}
for expr in expressions:
substitutions = {}
for dep_var, dep_expr in dependencies.items():
if dep_var in expr.identifiers:
code = dep_expr.get_code(subs=True)
substitutions[sympy.Symbol(dep_var, real=True)] = str2sympy(code).expr
if len(substitutions):
new_sympy_expr = str2sympy(expr.code).expr.xreplace(substitutions)
new_str_expr = sympy2str(new_sympy_expr)
expr.substituted_code = new_str_expr
dependencies[expr.var_name] = expr
else:
if substitute_vars == 'all':
dependencies[expr.var_name] = expr
elif substitute_vars == self.var_name:
if self.var_name in expr.identifiers:
dependencies[expr.var_name] = expr
else:
ids = expr.identifiers
for var in substitute_vars:
if var in ids:
dependencies[expr.var_name] = expr
break

        # Goal: get the final differential equation
        # Hint: this step requires the expression variables to be unique
substitutions = {}
for dep_var, dep_expr in dependencies.items():
code = dep_expr.get_code(subs=True)
substitutions[sympy.Symbol(dep_var, real=True)] = str2sympy(code).expr
if len(substitutions):
new_sympy_expr = str2sympy(final_exp.code).expr.xreplace(substitutions)
new_str_expr = sympy2str(new_sympy_expr)
final_exp.substituted_code = new_str_expr

def get_f_expressions(self, substitute_vars=None):
if self.f_expr is None:
return []
self._substitute(self.f_expr, self.expressions, substitute_vars=substitute_vars)

return_expressions = []
# the derivative expression
dif_eq_code = self.f_expr.get_code(subs=True)
return_expressions.append(Expression(f'_df{self.var_name}_dt', dif_eq_code))
# needed variables
need_vars = tools.get_identifiers(dif_eq_code)
# get the total return expressions
for expr in self.expressions[::-1]:
if expr.var_name in need_vars:
if expr.substituted_code is None:
code = expr.code
else:
code = expr.substituted_code
return_expressions.append(Expression(expr.var_name, code))
need_vars |= tools.get_identifiers(code)
return return_expressions[::-1]
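    # Worked sketch (names hypothetical): for
    #   eq = SingleDiffEq(var_name='V', variables=['alpha'],
    #                     expressions=['0.1 * (V + 40)'],
    #                     derivative_expr='alpha * (1 - V)',
    #                     scope={}, func_name='dV')
    # eq.get_f_expressions() keeps the intermediate:
    #   [alpha = 0.1 * (V + 40), _dfV_dt = alpha * (1 - V)]
    # while eq.get_f_expressions(substitute_vars='all') inlines it:
    #   [_dfV_dt = 0.1 * (V + 40) * (1 - V)]   (up to sympy's formatting)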

def _replace_expressions(self, expressions, name, y_sub, t_sub=None):
"""Replace expressions of df part.

Parameters
----------
expressions : list, tuple
The list/tuple of expressions.
name : str
The name of the new expression.
y_sub : str
The new name of the variable "y".
t_sub : str, optional
The new name of the variable "t".

Returns
-------
list_of_expr : list
A list of expressions.
"""
return_expressions = []

# replacements
replacement = {self.var_name: y_sub}
if t_sub is not None:
replacement[self.t_name] = t_sub

# replace variables in expressions
for expr in expressions:
replace = False
identifiers = expr.identifiers
for repl_var in replacement.keys():
if repl_var in identifiers:
replace = True
break
if replace:
code = tools.word_replace(expr.code, replacement)
new_expr = Expression(f"{expr.var_name}_{name}", code)
return_expressions.append(new_expr)
replacement[expr.var_name] = new_expr.var_name
return return_expressions

def replace_f_expressions(self, name, y_sub, t_sub=None):
"""Replace expressions of df part.

Parameters
----------
name : str
The name of the new expression.
y_sub : str
The new name of the variable "y".
t_sub : str, optional
The new name of the variable "t".

Returns
-------
list_of_expr : list
A list of expressions.
"""
return self._replace_expressions(self.get_f_expressions(),
name=name,
y_sub=y_sub,
t_sub=t_sub)

@property
def expr_names(self):
return [expr.var_name for expr in self.expressions]

+ 111
- 0
brainpy/integrators/utils.py View File

@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-

import inspect
from copy import deepcopy

from brainpy import backend
from brainpy import errors

__all__ = [
'numba_func',
'get_args',
]


def numba_func(code_scope, funcs_to_jit):
    # normalize once, so both the CPU and the CUDA branches accept a single name
    if isinstance(funcs_to_jit, str):
        funcs_to_jit = [funcs_to_jit]

    if backend.get_backend() in ['numba', 'numba-parallel']:
        from brainpy.backend.runners.numba_cpu_runner import NUMBA_PROFILE
        import numba as nb

        profiles = deepcopy(NUMBA_PROFILE)
        profiles.pop('parallel')
        for f in funcs_to_jit:
            code_scope[f] = nb.jit(**profiles)(code_scope[f])

    elif backend.get_backend() == 'numba-cuda':
        from numba import cuda

        for f in funcs_to_jit:
            code_scope[f] = cuda.jit(code_scope[f], device=True)
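# Sketch of the intended effect (backend names as checked above): with the
# backend set to 'numba', numba_func(scope, 'step') rebinds scope['step'] to
# nb.jit(...)(scope['step']); under 'numba-cuda' it becomes a CUDA device
# function instead.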


def get_args(f):
"""Get the function arguments.

    >>> def f1(a, b, t, *args, c=1): pass
    >>> get_args(f1)
    ([], ['a', 'b'], ['t', '*args', 'c'], ['a', 'b', 't', '*args', 'c=1'])

    >>> def f2(a, b, *args, c=1, **kwargs): pass
    >>> get_args(f2)
    brainpy.errors.DiffEqError: Dicts of keyword arguments are not supported: **kwargs

    >>> def f3(a, b, t, c=1, d=2): pass
    >>> get_args(f3)
    ([], ['a', 'b'], ['t', 'c', 'd'], ['a', 'b', 't', 'c=1', 'd=2'])

    >>> def f4(a, b, t, *args): pass
    >>> get_args(f4)
    ([], ['a', 'b'], ['t', '*args'], ['a', 'b', 't', '*args'])

    >>> scope = {}
    >>> exec(compile('def f5(a, b, t, *args): pass', '', 'exec'), scope)
    >>> get_args(scope['f5'])
    ([], ['a', 'b'], ['t', '*args'], ['a', 'b', 't', '*args'])

Parameters
----------
f : callable
The function.

Returns
-------
    args : tuple
        The class keywords, the variable names, the other arguments, and
        the original arguments.
"""

# 1. get the function arguments
reduced_args = []
original_args = []

for name, par in inspect.signature(f).parameters.items():
if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
reduced_args.append(par.name)

elif par.kind is inspect.Parameter.VAR_POSITIONAL:
reduced_args.append(f'*{par.name}')

elif par.kind is inspect.Parameter.KEYWORD_ONLY:
reduced_args.append(par.name)

        elif par.kind is inspect.Parameter.POSITIONAL_ONLY:
            raise errors.DiffEqError('Positional-only parameters are not supported, e.g., "/".')
        elif par.kind is inspect.Parameter.VAR_KEYWORD:
            raise errors.DiffEqError(f'Dicts of keyword arguments are not supported: {str(par)}')
        else:
            raise errors.DiffEqError(f'Unknown argument type: {par.kind}')

original_args.append(str(par))

# 2. analyze the function arguments
# 2.1 class keywords
class_kw = []
if reduced_args[0] in backend.CLASS_KEYWORDS:
class_kw.append(reduced_args[0])
reduced_args = reduced_args[1:]
for a in reduced_args:
if a in backend.CLASS_KEYWORDS:
raise errors.DiffEqError(f'Class keywords "{a}" must be defined '
f'as the first argument.')
# 2.2 variable names
var_names = []
for a in reduced_args:
if a == 't':
break
var_names.append(a)
else:
        raise ValueError('Cannot find the time variable "t".')
other_args = reduced_args[len(var_names):]
return class_kw, var_names, other_args, original_args

+ 11
- 4
brainpy/measure.py View File

@@ -1,9 +1,13 @@
# -*- coding: utf-8 -*-

import numpy as np
from numba import njit

from . import profile
from brainpy import backend

try:
from numba import njit
except ModuleNotFoundError:
njit = None

__all__ = [
'cross_correlation',
@@ -18,13 +22,16 @@ __all__ = [
###############################


def _cc(states, i, j):
    sqrt_ij = np.sqrt(np.sum(states[i]) * np.sum(states[j]))
    k = 0. if sqrt_ij == 0. else np.sum(states[i] * states[j]) / sqrt_ij
    return k


# JIT-compile only when numba is available; the pure-Python fallback works as-is
if njit is not None:
    _cc = njit(_cc)


def cross_correlation(spikes, bin_size):
"""Calculate cross correlation index between neurons.

@@ -207,7 +214,7 @@ def firing_rate(sp_matrix, width, window='gaussian'):
rate = np.sum(sp_matrix, axis=1)

# window
dt = profile.get_dt()
dt = backend.get_dt()
if window == 'gaussian':
width1 = 2 * width / dt
width2 = int(np.around(width1))


+ 0
- 377
brainpy/profile.py View File

@@ -1,377 +0,0 @@
# -*- coding: utf-8 -*-

"""
The overall framework settings, configured through the ``profile.py`` API.
"""

from numba import cuda

__all__ = [
'set',

'run_on_cpu',
'run_on_gpu',

'set_backend',
'get_backend',

'set_device',
'get_device',

'set_dt',
'get_dt',

'set_numerical_method',
'get_numerical_method',

'set_numba_profile',
'get_numba_profile',

'get_num_thread_gpu',

'is_jit',
'is_merge_integrators',
'is_merge_steps',
'is_substitute_equation',
'show_code_scope',
'show_format_code',
]

_jit = False
_backend = 'numpy'
_device = 'cpu'
_dt = 0.1
_method = 'euler'
_numba_setting = {
'nopython': True,
'fastmath': True,
'nogil': True,
'parallel': False
}
_show_format_code = False
_show_code_scope = False
_substitute_equation = False
_merge_integrators = True
_merge_steps = False
_num_thread_gpu = None


def set(
jit=None,
device=None,
numerical_method=None,
dt=None,
float_type=None,
int_type=None,
merge_integrators=None,
merge_steps=None,
substitute=None,
show_code=None,
show_code_scope=None
):
# JIT and device
if device is not None and jit is None:
        assert isinstance(device, str), "'device' must be a string."
set_device(_jit, device=device)
if jit is not None:
assert isinstance(jit, bool), "'jit' must be True or False."
if device is not None:
            assert isinstance(device, str), "'device' must be a string."
set_device(jit, device=device)

# numerical integration method
if numerical_method is not None:
assert isinstance(numerical_method, str), '"numerical_method" must be a string.'
set_numerical_method(numerical_method)

# numerical integration precision
if dt is not None:
assert isinstance(dt, (float, int)), '"dt" must be float or int.'
set_dt(dt)

# default float type
if float_type is not None:
from .backend import _set_default_float
_set_default_float(float_type)

# default int type
if int_type is not None:
from .backend import _set_default_int
_set_default_int(int_type)

# option to merge integral functions
if merge_integrators is not None:
assert isinstance(merge_integrators, bool), '"merge_integrators" must be True or False.'
if run_on_gpu() and not merge_integrators:
            raise ValueError('GPU mode does not support "merge_integrators = False".')
global _merge_integrators
_merge_integrators = merge_integrators

# option to merge step functions
if merge_steps is not None:
assert isinstance(merge_steps, bool), '"merge_steps" must be True or False.'
global _merge_steps
_merge_steps = merge_steps

# option of the equation substitution
if substitute is not None:
assert isinstance(substitute, bool), '"substitute" must be True or False.'
global _substitute_equation
_substitute_equation = substitute

# option of the formatted code output
if show_code is not None:
assert isinstance(show_code, bool), '"show_code" must be True or False.'
global _show_format_code
_show_format_code = show_code

# option of the formatted code scope
if show_code_scope is not None:
assert isinstance(show_code_scope, bool), '"show_code_scope" must be True or False.'
global _show_code_scope
_show_code_scope = show_code_scope


def set_device(jit, device=None):
"""Set the backend and the device to deploy the models.

Parameters
----------
jit : bool
Whether use the jit acceleration.
device : str, optional
The device name.
"""

# jit
# ---

global _jit

if _jit != jit:
_jit = jit

# device
# ------

global _device
global _num_thread_gpu

if device is None:
return

device = device.lower()
if _device != device:
if not jit:
if device != 'cpu':
print(f'Non-JIT mode now only supports "cpu" device, not "{device}".')
else:
_device = device
else:
if device == 'cpu':
set_numba_profile(parallel=False)
elif device == 'multi-cpu':
set_numba_profile(parallel=True)
else:
if device.startswith('gpu'):
# get cuda id
cuda_id = device.replace('gpu', '')
if cuda_id == '':
cuda_id = 0
device = f'{device}0'
else:
                        cuda_id = int(cuda_id)

# set cuda
if cuda.is_available():
cuda.select_device(cuda_id)
else:
raise ValueError('Cuda is not available. Cannot set gpu backend.')

gpu = cuda.get_current_device()
_num_thread_gpu = gpu.MAX_THREADS_PER_BLOCK

else:
raise ValueError(f'Unknown device in Numba mode: {device}.')
_device = device


def get_device():
"""Get the device name.

Returns
-------
device: str
Device name.

"""
return _device


def is_jit():
"""Check whether the backend is ``numba``.

Returns
-------
jit : bool
True or False.
"""
return _jit


def run_on_cpu():
"""Check whether the device is "CPU".

Returns
-------
device : bool
True or False.
"""
return _device.endswith('cpu')


def run_on_gpu():
"""Check whether the device is "GPU".

Returns
-------
device : bool
True or False.
"""
return _device.startswith('gpu')


def set_backend(backend):
"""Set the running backend.

Parameters
----------
backend : str
The backend name.
"""
if backend not in ['numpy', 'pytorch']:
raise ValueError(f'BrainPy now supports "numpy" or "pytorch" backend, not "{backend}".')

global _backend
_backend = backend


def get_backend():
"""Get the used backend of BrainPy.

Returns
-------
backend : str
The backend name.
"""
return _backend


def set_numba_profile(**kwargs):
"""Set the compilation options of Numba JIT function.

Parameters
----------
kwargs : Any
The arguments, including ``cache``, ``fastmath``,
``parallel``, ``nopython``.
"""
global _numba_setting

if 'fastmath' in kwargs:
_numba_setting['fastmath'] = kwargs.pop('fastmath')
if 'nopython' in kwargs:
_numba_setting['nopython'] = kwargs.pop('nopython')
if 'nogil' in kwargs:
_numba_setting['nogil'] = kwargs.pop('nogil')
if 'parallel' in kwargs:
_numba_setting['parallel'] = kwargs.pop('parallel')


def get_numba_profile():
"""Get the compilation setting of numba JIT function.

Returns
-------
numba_setting : dict
Numba setting.
"""
return _numba_setting


def set_dt(dt):
"""Set the numerical integrator precision.

Parameters
----------
dt : float
Numerical integration precision.
"""
assert isinstance(dt, float)
global _dt
_dt = dt


def get_dt():
"""Get the numerical integrator precision.

Returns
-------
dt : float
Numerical integration precision.
"""
return _dt


def set_numerical_method(method):
"""Set the default numerical integrator method for differential equations.

Parameters
----------
method : str, callable
Numerical integrator method.
"""
from brainpy.integration import _SUPPORT_METHODS

if not isinstance(method, str):
raise ValueError(f'Only support string, not {type(method)}.')
if method not in _SUPPORT_METHODS:
raise ValueError(f'Unsupported numerical method: {method}.')

global _method
_method = method


def get_numerical_method():
"""Get the default numerical integrator method.

Returns
-------
method : str
The default numerical integrator method.
"""
return _method


def is_merge_integrators():
return _merge_integrators


def is_merge_steps():
return _merge_steps


def is_substitute_equation():
return _substitute_equation


def show_code_scope():
return _show_code_scope


def show_format_code():
return _show_format_code


def get_num_thread_gpu():
return _num_thread_gpu

+ 10
- 0
brainpy/simulation/__init__.py View File

@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-

from .brain_objects import *
from .constants import *
from .delay import *
from .dynamic_system import *
from .monitors import *
from .runner import *
from .utils import *


+ 266
- 0
brainpy/simulation/brain_objects.py View File

@@ -0,0 +1,266 @@
# -*- coding: utf-8 -*-

from collections import OrderedDict

from brainpy import backend
from brainpy import errors
from brainpy.simulation import delay
from brainpy.simulation import utils
from brainpy.simulation.dynamic_system import DynamicSystem

__all__ = [
'NeuGroup',
'SynConn',
'TwoEndConn',
'Network',
]

_NeuGroup_NO = 0
_TwoEndSyn_NO = 0


class NeuGroup(DynamicSystem):
"""Neuron Group.

    Parameters
    ----------
    size : int, tuple
        The neuron group geometry.
    monitors : list, tuple
        Variables to monitor.
    name : str
        The name of the neuron group.
    show_code : bool
        Whether to show the formatted code.
"""

def __init__(self, size, monitors=None, name=None, show_code=False):
# name
# -----
if name is None:
name = ''
else:
name = '_' + name
global _NeuGroup_NO
_NeuGroup_NO += 1
name = f'NG{_NeuGroup_NO}{name}'

# size
# ----
if isinstance(size, (list, tuple)):
if len(size) <= 0:
raise errors.ModelDefError('size must be int, or a tuple/list of int.')
if not isinstance(size[0], int):
raise errors.ModelDefError('size must be int, or a tuple/list of int.')
size = tuple(size)
elif isinstance(size, int):
size = (size,)
else:
raise errors.ModelDefError('size must be int, or a tuple/list of int.')
self.size = size

# initialize
# ----------
super(NeuGroup, self).__init__(steps={'update': self.update},
monitors=monitors,
name=name,
show_code=show_code)

def update(self, *args):
raise NotImplementedError


class SynConn(DynamicSystem):
"""Synaptic Connections.
"""

def __init__(self, steps, monitors=None, name=None, show_code=False):
# check delay update
if callable(steps):
steps = OrderedDict([(steps.__name__, steps)])
elif isinstance(steps, (tuple, list)) and callable(steps[0]):
steps = OrderedDict([(step.__name__, step) for step in steps])
else:
assert isinstance(steps, dict)

if hasattr(self, 'constant_delays'):
for key, delay_var in self.constant_delays.items():
if delay_var.update not in steps:
delay_name = f'{key}_delay_update'
setattr(self, delay_name, delay_var.update)
steps[delay_name] = delay_var.update

# initialize super class
super(SynConn, self).__init__(steps=steps, monitors=monitors, name=name, show_code=show_code)

# delay assignment
if hasattr(self, 'constant_delays'):
for key, delay_var in self.constant_delays.items():
delay_var.name = f'{self.name}_delay_{key}'

def register_constant_delay(self, key, size, delay_time):
if not hasattr(self, 'constant_delays'):
self.constant_delays = {}
if key in self.constant_delays:
raise errors.ModelDefError(f'"{key}" has been registered as an constant delay.')
self.constant_delays[key] = delay.ConstantDelay(size, delay_time)
return self.constant_delays[key]

def update(self, *args):
raise NotImplementedError


class TwoEndConn(SynConn):
"""Two End Synaptic Connections.

    Parameters
    ----------
    pre : NeuGroup
        Pre-synaptic neuron group.
    post : NeuGroup
        Post-synaptic neuron group.
    monitors : list, tuple
        Variables to monitor.
    name : str
        The name of the synaptic connection.
    show_code : bool
        Whether to show the formatted code.
"""

def __init__(self, pre, post, monitors=None, name=None, show_code=False):
# name
# ----
if name is None:
name = ''
else:
name = '_' + name
global _TwoEndSyn_NO
_TwoEndSyn_NO += 1
name = f'TEC{_TwoEndSyn_NO}{name}'

# pre or post neuron group
# ------------------------
if not isinstance(pre, NeuGroup):
raise errors.ModelUseError('"pre" must be an instance of NeuGroup.')
self.pre = pre
if not isinstance(post, NeuGroup):
raise errors.ModelUseError('"post" must be an instance of NeuGroup.')
self.post = post

# initialize
# ----------
super(TwoEndConn, self).__init__(steps={'update': self.update},
name=name,
monitors=monitors,
show_code=show_code)


class Network(object):
"""The main simulation controller in ``BrainPy``.

    ``Network`` handles the running of a simulation. It contains a set
    of objects that are added with `add()`. The `run()` method actually
    runs the simulation. The main loop runs in the order in which the
    objects were added. The objects in the `Network` are accessible via
    their names, e.g. `net.name` returns the object with that name.
"""

def __init__(self, *args, show_code=False, **kwargs):
# record the current step
self.t_start = 0.
self.t_end = 0.

# store all nodes
self.all_nodes = OrderedDict()

# store the step function
self.run_func = None
self.show_code = show_code

# add nodes
self.add(*args, **kwargs)

def __getattr__(self, item):
if item in self.all_nodes:
return self.all_nodes[item]
else:
return super(Network, self).__getattribute__(item)

def _add_obj(self, obj, name=None):
# 1. check object type
if not isinstance(obj, DynamicSystem):
            raise ValueError(f'Unknown object type "{type(obj)}". '
                             f'Network only supports DynamicSystem instances, '
                             f'e.g., {NeuGroup.__name__} and '
                             f'{TwoEndConn.__name__}.')
# 2. check object name
name = obj.name if name is None else name
if name in self.all_nodes:
            raise KeyError(f'Name "{name}" has been used in the network, '
                           f'please choose another name.')
# 3. add object to the network
self.all_nodes[name] = obj
if obj.name != name:
self.all_nodes[obj.name] = obj

def add(self, *args, **kwargs):
"""Add object (neurons or synapses) to the network.

Parameters
----------
args
The nameless objects.
kwargs
The named objects, which can be accessed by `net.xxx`
(xxx is the name of the object).
"""
for obj in args:
self._add_obj(obj)
for name, obj in kwargs.items():
self._add_obj(obj, name)

def run(self, duration, inputs=(), report=False, report_percent=0.1):
"""Run the simulation for the given duration.

        This function provides the most convenient way to run the network;
        a usage sketch follows the class definition below.

Parameters
----------
duration : int, float, tuple, list
The amount of simulation time to run for.
inputs : list, tuple
The receivers, external inputs and durations.
report : bool
Report the progress of the simulation.
        report_percent : float
            The fraction of the total run length between progress reports.
"""
# preparation
start, end = utils.check_duration(duration)
dt = backend.get_dt()
ts = backend.arange(start, end, dt)

# build the network
run_length = ts.shape[0]
format_inputs = utils.format_net_level_inputs(inputs, run_length)
net_runner = backend.get_net_runner()(all_nodes=self.all_nodes)
self.run_func = net_runner.build(run_length=run_length,
formatted_inputs=format_inputs,
return_code=False,
show_code=self.show_code)

# run the network
utils.run_model(self.run_func, times=ts, report=report, report_percent=report_percent)

# end
self.t_start, self.t_end = start, end
for obj in self.all_nodes.values():
if len(obj.mon['vars']) > 0:
obj.mon['ts'] = ts

@property
def ts(self):
"""Get the time points of the network.
"""
return backend.arange(self.t_start, self.t_end, backend.get_dt())
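A minimal end-to-end sketch of the intended workflow (the model, its dynamics, and the step-function signature are illustrative assumptions, not part of the documented API):

class LIF(NeuGroup):                      # hypothetical neuron model
    target_backend = 'general'

    def __init__(self, size, **kwargs):
        # monitored attributes must exist before the monitor check in __init__
        self.V = backend.zeros(size)
        self.input = backend.zeros(size)
        super(LIF, self).__init__(size=size, **kwargs)

    def update(self, _t, _i, _dt):
        self.V += (self.input - self.V) * _dt   # toy dynamics
        self.input[:] = 0.

group = LIF(100, monitors=['V'])
net = Network(pop=group)
net.run(duration=100., inputs=[(group, 'input', 1.5)], report=True)
# net.pop.mon['V'] then holds the recorded trajectory, net.ts the time points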

+ 16
- 0
brainpy/simulation/constants.py View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-


UNKNOWN_TYPE = 'unknown'  # name of an unknown object type
NEU_GROUP_TYPE = 'NeuGroup' # name of the neuron group
SYN_CONN_TYPE = 'SynConn' # name of the synapse connection
TWO_END_TYPE = 'TwoEndConn' # name of the two-end synaptic connection
SUPPORTED_TYPES = [NEU_GROUP_TYPE, SYN_CONN_TYPE, TWO_END_TYPE, UNKNOWN_TYPE]

# input operations
SUPPORTED_INPUT_OPS = {'-': 'sub',
'+': 'add',
'x': 'mul',
'*': 'mul',
'/': 'div',
'=': 'assign'}

+ 62
- 0
brainpy/simulation/delay.py View File

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-

import math

from brainpy import backend

__all__ = [
'ConstantDelay',
'push_type1',
'push_type2',
'pull_type0',
'pull_type1',
]


class ConstantDelay(object):
"""Constant delay variable for synapse computation.

"""

def __init__(self, size, delay_time):
self.delay_time = delay_time
self.delay_num_step = int(math.ceil(delay_time / backend.get_dt())) + 1
self.delay_in_idx = 0
self.delay_out_idx = self.delay_num_step - 1

if isinstance(size, int):
size = (size,)
size = tuple(size)
self.delay_data = backend.zeros((self.delay_num_step + 1,) + size)

def push(self, idx_or_val, value=None):
if value is None:
self.delay_data[self.delay_in_idx] = idx_or_val
else:
self.delay_data[self.delay_in_idx][idx_or_val] = value

def pull(self, idx=None):
if idx is None:
return self.delay_data[self.delay_out_idx]
else:
return self.delay_data[self.delay_out_idx][idx]

def update(self):
self.delay_in_idx = (self.delay_in_idx + 1) % self.delay_num_step
self.delay_out_idx = (self.delay_out_idx + 1) % self.delay_num_step


def push_type1(idx_or_val, delay_data, delay_in_idx):
delay_data[delay_in_idx] = idx_or_val


def push_type2(idx_or_val, value, delay_data, delay_in_idx):
delay_data[delay_in_idx][idx_or_val] = value


def pull_type0(delay_data, delay_out_idx):
return delay_data[delay_out_idx]


def pull_type1(idx, delay_data, delay_out_idx):
return delay_data[delay_out_idx][idx]
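A short sketch of the ring-buffer semantics above (values illustrative; assumes the numpy backend with dt = 0.1):

import numpy as np

delay = ConstantDelay(size=10, delay_time=1.0)   # ~11 steps of history at dt = 0.1

spikes = np.zeros(10)      # stand-in for this step's synaptic values
delay.push(spikes)         # write at the current input slot
delayed = delay.pull()     # read the slot `delay_time` in the past
delay.update()             # advance both ring indices by one step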

+ 181
- 0
brainpy/simulation/dynamic_system.py View File

@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-

from collections import OrderedDict

from brainpy import backend
from brainpy import errors
from brainpy.simulation import utils
from brainpy.simulation.monitors import Monitor

__all__ = [
'DynamicSystem',
]

_DynamicSystem_NO = 0


class DynamicSystem(object):
"""Base Dynamic System Class.

Parameters
----------
steps : callable, list of callable, dict
The callable function, or a list of callable functions.
monitors : list, tuple, None
Variables to monitor.
name : str
The name of the dynamic system.
host : any
The host to store data, including variables, functions, etc.
show_code : bool
        Whether to show the formatted code.
"""

target_backend = None

def __init__(self, steps, monitors=None, name=None, host=None, show_code=False):
# host of the data
# ----------------
if host is None:
host = self
self.host = host

# model
# -----
if callable(steps):
self.steps = OrderedDict([(steps.__name__, steps)])
elif isinstance(steps, (list, tuple)) and callable(steps[0]):
self.steps = OrderedDict([(step.__name__, step) for step in steps])
elif isinstance(steps, dict):
self.steps = steps
else:
            raise errors.ModelDefError(f'Unknown model type: {type(steps)}. Currently, BrainPy '
                                       f'only supports a function, or a list/tuple/dict of functions.')

# name
# ----
if name is None:
global _DynamicSystem_NO
name = f'DS{_DynamicSystem_NO}'
_DynamicSystem_NO += 1
if not name.isidentifier():
raise errors.ModelUseError(f'"{name}" isn\'t a valid identifier according to Python '
f'language definition. Please choose another name.')
self.name = name

# monitors
# ---------
if monitors is None:
monitors = []
self.mon = Monitor(monitors)
for var in self.mon['vars']:
if not hasattr(self, var):
raise errors.ModelDefError(f"Item {var} isn't defined in model {self}, "
f"so it can not be monitored.")

# runner
# -------
self.runner = backend.get_node_runner()(pop=self)

# run function
# ------------
self.run_func = None

# others
# ---
self.show_code = show_code
if self.target_backend is None:
raise errors.ModelDefError('Must define "target_backend".')
if isinstance(self.target_backend, str):
self._target_backend = (self.target_backend,)
elif isinstance(self.target_backend, (tuple, list)):
if not isinstance(self.target_backend[0], str):
raise errors.ModelDefError('"target_backend" must be a list/tuple of string.')
self._target_backend = tuple(self.target_backend)
else:
raise errors.ModelDefError(f'Unknown setting of "target_backend": {self.target_backend}')

def build(self, inputs, inputs_is_formatted=False, return_code=True, mon_length=0, show_code=False):
"""Build the object for running.

Parameters
----------
inputs : list, tuple, optional
The object inputs.
inputs_is_formatted : bool
Whether the "inputs" is formatted.
return_code : bool
Whether return the formatted codes.
mon_length : int
The monitor length.

Returns
-------
calls : list, tuple
The code lines to call step functions.
"""
if (self._target_backend[0] != 'general') and \
(backend.get_backend() not in self._target_backend):
            raise errors.ModelDefError(f'The model {self.name} is targeted to run on '
                                       f'{self._target_backend}, but the current backend '
                                       f'of BrainPy is {backend.get_backend()}.')
if not inputs_is_formatted:
inputs = utils.format_pop_level_inputs(inputs, self, mon_length)
return self.runner.build(formatted_inputs=inputs,
mon_length=mon_length,
return_code=return_code,
show_code=(self.show_code or show_code))

def run(self, duration, inputs=(), report=False, report_percent=0.1):
"""The running function.

Parameters
----------
duration : float, int, tuple, list
The running duration.
inputs : list, tuple
The model inputs with the format of ``[(key, value [operation])]``.
report : bool
            Whether to report the running progress.
report_percent : float
The percent of progress to report.
"""

# times
# ------
start, end = utils.check_duration(duration)
times = backend.arange(start, end, backend.get_dt())
run_length = backend.shape(times)[0]

# build run function
# ------------------
self.run_func = self.build(inputs, inputs_is_formatted=False, mon_length=run_length, return_code=False)

# run the model
# -------------
utils.run_model(self.run_func, times, report, report_percent)
self.mon['ts'] = times

def get_schedule(self):
"""Get the schedule (running order) of the update functions.

Returns
-------
schedule : list, tuple
The running order of update functions.
"""
return self.runner.get_schedule()

def set_schedule(self, schedule):
"""Set the schedule (running order) of the update functions.

        For example, if ``self.model`` has two step functions, `step1` and
        `step2`, you can set the schedule with:

>>> pop = DynamicSystem(...)
>>> pop.set_schedule(['input', 'step1', 'step2', 'monitor'])
"""
self.runner.set_schedule(schedule)

def __str__(self):
return self.name

+ 53
- 0
brainpy/simulation/monitors.py View File

@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-

from brainpy import backend
from brainpy import errors
from brainpy import tools

__all__ = [
'Monitor'
]


class Monitor(tools.DictPlus):
"""The basic Monitor class to store the past variable trajectories.
"""
def __init__(self, variables):
mon_items = []
mon_indices = []
item_content = {}
if variables is not None:
if isinstance(variables, (list, tuple)):
for var in variables:
if isinstance(var, str):
mon_items.append(var)
mon_indices.append(None)
item_content[var] = backend.zeros((1, 1))
elif isinstance(var, (tuple, list)):
mon_items.append(var[0])
mon_indices.append(var[1])
item_content[var[0]] = backend.zeros((1, 1))
else:
raise errors.ModelUseError(f'Unknown monitor item: {str(var)}')
elif isinstance(variables, dict):
for k, v in variables.items():
mon_items.append(k)
mon_indices.append(v)
item_content[k] = backend.zeros((1, 1))
else:
raise errors.ModelUseError(f'Unknown monitors type: {type(variables)}')
super(Monitor, self).__init__(ts=None,
vars=mon_items,
indices=mon_indices,
num_item=len(item_content),
**item_content)

def reshape(self, run_length):
for var in self['vars']:
val = self[var]
shape = backend.shape(val)
if run_length < shape[0]:
self[var] = val[:run_length]
elif run_length > shape[0]:
append = backend.zeros((run_length - shape[0],) + shape[1:])
self[var] = backend.vstack([val, append])
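    # Sketch of accepted monitor specs (names illustrative):
    #   Monitor(['V', ('spike', [0, 1, 2])])    # 'V' everywhere, 'spike' at 3 indices
    #   Monitor({'V': None, 'spike': [0, 1, 2]})
    # Each item starts as a zeros((1, 1)) placeholder and is resized to the
    # run length via `reshape(run_length)` before a simulation.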

+ 66
- 0
brainpy/simulation/runner.py View File

@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-

import abc

from brainpy import errors

__all__ = [
'AbstractRunner',
'NodeRunner',
'NetRunner',
]


class AbstractRunner(abc.ABC):
"""
Abstract base class for backend runner.
"""
@abc.abstractmethod
def build(self, *args, **kwargs):
pass


class NodeRunner(AbstractRunner):
"""
Abstract Node Runner.
"""
def __init__(self, host, steps):
self.host = host
self.steps = steps
self.schedule = ['input'] + list(self.steps.keys()) + ['monitor']

def get_schedule(self):
return self.schedule

def set_schedule(self, schedule):
if not isinstance(schedule, (list, tuple)):
raise errors.ModelUseError('"schedule" must be a list/tuple.')
all_func_names = ['input', 'monitor'] + list(self.steps.keys())
for s in schedule:
if s not in all_func_names:
raise errors.ModelUseError(f'Unknown step function "{s}" for model "{self.host}".')
self.schedule = schedule

@abc.abstractmethod
def set_data(self, *args, **kwargs):
pass

@abc.abstractmethod
def get_input_func(self, *args, **kwargs):
pass

@abc.abstractmethod
def get_monitor_func(self, *args, **kwargs):
pass

@abc.abstractmethod
def get_steps_func(self, *args, **kwargs):
pass


class NetRunner(AbstractRunner):
"""
Abstract Network Runner.
"""
def __init__(self, all_nodes):
self.all_nodes = all_nodes

+ 240
- 0
brainpy/simulation/utils.py View File

@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-

import time

from brainpy import backend
from brainpy import errors
from brainpy.simulation import constants

__all__ = [
'check_duration',
'run_model',
'format_pop_level_inputs',
'format_net_level_inputs',
]


def check_duration(duration):
"""Check the running duration.

Parameters
----------
    duration : int, float, tuple, list
        The running duration. It can be a number (which represents the end
        of the simulation) or a tuple/list of two numbers (which represents
        the [start, end] / [end, start] of the simulation).

Returns
-------
    duration : tuple
        The running duration as a (start, end) tuple.
"""
if isinstance(duration, (int, float)):
start, end = 0., duration
elif isinstance(duration, (tuple, list)):
        assert len(duration) == 2, 'A duration must be given as "(start, end)" or "end".'
start, end = duration
else:
        raise ValueError(f'Unknown duration type: {type(duration)}. Currently, BrainPy only '
                         f'supports duration specifications with the format of "(start, end)" '
                         f'or "end".')

if start > end:
start, end = end, start
return start, end
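# Quick sketch of the normalization above (values illustrative):
#   check_duration(100.)        -> (0., 100.)
#   check_duration((100., 20.)) -> (20., 100.)   # swapped into (start, end)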


def run_model(run_func, times, report, report_percent):
"""Run the model.

The "run_func" can be the step run function of a population, or a network.

Parameters
----------
run_func : callable
The step run function.
times : iterable
The model running times.
report : bool
        Whether to report the progress of the running.
report_percent : float
The percent of the total running length for each report.
"""
run_length = len(times)
dt = backend.get_dt()
if report:
t0 = time.time()
for i, t in enumerate(times[:1]):
run_func(_t=t, _i=i, _dt=dt)
print('Compilation used {:.4f} s.'.format(time.time() - t0))

print("Start running ...")
report_gap = int(run_length * report_percent)
t0 = time.time()
for run_idx in range(1, run_length):
run_func(_t=times[run_idx], _i=run_idx, _dt=dt)
if (run_idx + 1) % report_gap == 0:
percent = (run_idx + 1) / run_length * 100
print('Run {:.1f}% used {:.3f} s.'.format(percent, time.time() - t0))
print('Simulation is done in {:.3f} s.'.format(time.time() - t0))
print()
else:
for run_idx in range(run_length):
run_func(_t=times[run_idx], _i=run_idx, _dt=dt)


def format_pop_level_inputs(inputs, host, mon_length):
"""Format the inputs of a population.

Parameters
----------
inputs : tuple, list
The inputs of the population.
host : Population
The host which contains all data.
mon_length : int
The monitor length.

Returns
-------
formatted_inputs : tuple, list
The formatted inputs of the population.
"""
if inputs is None:
inputs = []
if not isinstance(inputs, (tuple, list)):
raise errors.ModelUseError('"inputs" must be a tuple/list.')
if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)):
if isinstance(inputs[0], str):
inputs = [inputs]
else:
            raise errors.ModelUseError('Unknown input structure. BrainPy only supports '
                                       'inputs with the format of "(key, value, [operation])".')
for input in inputs:
if not 2 <= len(input) <= 3:
raise errors.ModelUseError('For each target, you must specify "(key, value, [operation])".')
if len(input) == 3 and input[2] not in constants.SUPPORTED_INPUT_OPS:
raise errors.ModelUseError(f'Input operation only supports '
f'"{list(constants.SUPPORTED_INPUT_OPS.keys())}", '
f'not "{input[2]}".')

# format inputs
# -------------
formatted_inputs = []
for input in inputs:
# key
if not isinstance(input[0], str):
raise errors.ModelUseError('For each input, input[0] must be a string '
'to specify variable of the target.')
key = input[0]
if not hasattr(host, key):
raise errors.ModelUseError(f'Input target key "{key}" is not defined in {host}.')

# value and data type
val = input[1]
if isinstance(input[1], (int, float)):
data_type = 'fix'
else:
shape = backend.shape(input[1])
if shape[0] == mon_length:
data_type = 'iter'
else:
data_type = 'fix'

# operation
if len(input) == 3:
ops = input[2]
else:
ops = '+'
if ops not in constants.SUPPORTED_INPUT_OPS:
            raise errors.ModelUseError(f'Currently, BrainPy only supports operations '
                                       f'{list(constants.SUPPORTED_INPUT_OPS.keys())}, '
                                       f'not {ops}.')
# input
format_inp = (key, val, ops, data_type)
formatted_inputs.append(format_inp)

return formatted_inputs
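# Example of the resulting format (names illustrative): for a host that has an
# attribute `V`,
#   format_pop_level_inputs([('V', 0.5)], host, mon_length=1000)
# returns [('V', 0.5, '+', 'fix')] -- the default operation is '+', and a
# scalar value is tagged with the 'fix' data type.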


def format_net_level_inputs(inputs, run_length):
"""Format the inputs of a network.

Parameters
----------
inputs : tuple
The inputs.
run_length : int
The running length.

Returns
-------
formatted_input : dict
The formatted input.
"""
from brainpy.simulation import brain_objects

# 1. format the inputs to standard
# formats and check the inputs
if not isinstance(inputs, (tuple, list)):
raise errors.ModelUseError('"inputs" must be a tuple/list.')
if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)):
if isinstance(inputs[0], brain_objects.DynamicSystem):
inputs = [inputs]
else:
            raise errors.ModelUseError('Unknown input structure. BrainPy only supports '
                                       'inputs with the format of "(target, key, value, [operation])".')
for input in inputs:
if not 3 <= len(input) <= 4:
raise errors.ModelUseError('For each target, you must specify '
'"(target, key, value, [operation])".')
if len(input) == 4:
if input[3] not in constants.SUPPORTED_INPUT_OPS:
raise errors.ModelUseError(f'Input operation only supports '
f'"{list(constants.SUPPORTED_INPUT_OPS.keys())}", '
f'not "{input[3]}".')

# 2. format inputs
formatted_inputs = {}
for input in inputs:
# target
if isinstance(input[0], brain_objects.DynamicSystem):
target = input[0]
target_name = input[0].name
else:
raise KeyError(f'Unknown input target: {str(input[0])}')

# key
key = input[1]
if not isinstance(key, str):
raise errors.ModelUseError('For each input, input[1] must be a string '
'to specify variable of the target.')
if not hasattr(target, key):
            raise errors.ModelUseError(f'Target {target} does not have key {key}, '
                                       f'so the input cannot be assigned to it.')

# value and data type
val = input[2]
if isinstance(input[2], (int, float)):
data_type = 'fix'
else:
shape = backend.shape(val)
if shape[0] == run_length:
data_type = 'iter'
else:
data_type = 'fix'

# operation
if len(input) == 4:
ops = input[3]
else:
ops = '+'

# final result
if target_name not in formatted_inputs:
formatted_inputs[target_name] = []
format_inp = (key, val, ops, data_type)
formatted_inputs[target_name].append(format_inp)
return formatted_inputs


+ 0
- 1
brainpy/tools/__init__.py View File

@@ -3,4 +3,3 @@
from .ast2code import *
from .codes import *
from .dicts import *
from .functions import *

+ 3
- 2
brainpy/tools/ast2code.py View File

@@ -10,8 +10,9 @@ import ast
import sys
from contextlib import contextmanager


__all__ = [
    'ast2code',
]


@@ -314,7 +315,7 @@ class Transformer(ast.NodeVisitor):
self.write(')')
if getattr(node, 'returns', None):
self.write(' -> ')
            self.visit(node.returns)
self.write(':')
self.write_newline()



+ 68
- 516
brainpy/tools/codes.py View File

@@ -1,64 +1,26 @@
# -*- coding: utf-8 -*-

import ast
import inspect
import re
from types import LambdaType

from .ast2code import ast2code
from .dicts import DictPlus
from ..errors import CodeError
from ..errors import DiffEquationError

__all__ = [
'NoiseHandler',

'CodeLineFormatter',
'format_code',

'LineFormatterForTrajectory',
'format_code_for_trajectory',

'FindAtomicOp',
'find_atomic_op',

# replace function calls
'replace_func',
'FuncCallFinder',

    # tools for code string
'get_identifiers',
'get_line_indent',

'indent',
'deindent',
'word_replace',

    # other tools
'is_lambda_function',

'func_call',
'get_main_code',
'get_func_source',
]


def is_lambda_function(func):
"""Check whether the function is a ``lambda`` function. Comes from
https://stackoverflow.com/questions/23852423/how-to-check-that-variable-is-a-lambda-function

Parameters
----------
func : callable function
The function.

Returns
-------
bool
True of False.
"""
return isinstance(func, LambdaType) and func.__name__ == "<lambda>"
######################################
# String tools
######################################


def get_identifiers(expr, include_numbers=False):
@@ -102,465 +64,6 @@ def get_identifiers(expr, include_numbers=False):
return (identifiers - _ID_KEYWORDS) | numbers
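# Sketch (identifier strings illustrative):
#   get_identifiers('V + g_Na * m ** 3 * h * (V - E_Na)')
#     -> {'V', 'g_Na', 'm', 'h', 'E_Na'}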






class NoiseHandler(object):
normal_pattern = re.compile(r'(_normal_like_)\((\w+)\)')

@staticmethod
def vector_replace_f(m):
return 'numpy.random.normal(0., 1., ' + m.group(2) + '.shape)'

@staticmethod
def scalar_replace_f(m):
return 'numpy.random.normal(0., 1.)'

@staticmethod
def cuda_replace_f(m):
return 'xoroshiro128p_normal_float64(rng_states, _obj_i)'
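    # Substitution sketch: applied to a generated code string, e.g.
    #   NoiseHandler.normal_pattern.sub(NoiseHandler.vector_replace_f, code)
    # turns '_normal_like_(V)' into 'numpy.random.normal(0., 1., V.shape)';
    # the scalar and CUDA variants swap in the other replacement functions.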



class FuncCallFinder(ast.NodeTransformer):
""""""

def __init__(self, func_name):
self.name = func_name
self.args = []
self.kwargs = {}

def _get_attr_value(self, node, names):
if hasattr(node, 'value'):
names.insert(0, node.attr)
return self._get_attr_value(node.value, names)
else:
assert hasattr(node, 'id')
names.insert(0, node.id)
return names

def visit_Call(self, node):
if getattr(node, 'starargs', None) is not None:
raise ValueError("Variable number of arguments not supported")
if getattr(node, 'kwargs', None) is not None:
raise ValueError("Keyword arguments not supported")

if hasattr(node.func, 'id') and node.func.id == self.name:
for arg in node.args:
if isinstance(arg, ast.Name):
self.args.append(arg.id)
elif isinstance(arg, ast.Num):
self.args.append(arg.n)
else:
s = ast2code(ast.fix_missing_locations(arg))
self.args.append(s.strip())
for kv in node.keywords:
if isinstance(kv.value, ast.Name):
self.kwargs[kv.arg] = kv.value.id
elif isinstance(kv.value, ast.Num):
self.kwargs[kv.arg] = kv.value.n
else:
s = ast2code(ast.fix_missing_locations(kv.value))
self.kwargs[kv.arg] = s.strip()
return ast.Name(f'_{self.name}_res')
else:
args = [self.visit(arg) for arg in node.args]
keywords = [self.visit(kv) for kv in node.keywords]
return ast.Call(func=node.func, args=args, keywords=keywords)


def replace_func(code, func_name):
tree = ast.parse(code.strip())
w = FuncCallFinder(func_name)
tree = w.visit(tree)
tree = ast.fix_missing_locations(tree)
new_code = ast2code(tree)
return new_code, w.args, w.kwargs
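# Sketch (function name illustrative):
#   replace_func('y = my_f(a, 2.0) + 1', 'my_f')
# rewrites the call into a placeholder and reports its arguments:
#   ('y = _my_f_res + 1', ['a', 2.0], {})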


def get_main_code(func):
    """Get the main code string of a function.

    For a lambda function, return its expression body wrapped in a
    "return" statement; for an ordinary function, return the source of
    the body after the signature line.

    Parameters
    ----------
    func : callable, optional
        The function to inspect.

    Returns
    -------
    code : str
        The main code string.
    """
if func is None:
return ''
elif callable(func):
if is_lambda_function(func):
func_code = get_func_source(func)
splits = func_code.split(':')
if len(splits) != 2:
                raise ValueError(f'Cannot parse function: \n{func_code}')
return f'return {splits[1]}'

else:
func_codes = inspect.getsourcelines(func)[0]
idx = 0
for i, line in enumerate(func_codes):
idx += 1
line = line.replace(' ', '')
if '):' in line:
break
else:
code = "\n".join(func_codes)
                raise ValueError(f'Cannot parse function: \n{code}')
return ''.join(func_codes[idx:])
else:
raise ValueError(f'Unknown function type: {type(func)}.')


def get_line_indent(line, spaces_per_tab=4):
line = line.replace('\t', ' ' * spaces_per_tab)
return len(line) - len(line.lstrip())


class FindAtomicOp(ast.NodeTransformer):
def __init__(self, var2idx):
self.var2idx = var2idx
self.left = None
self.right = None

def visit_Assign(self, node):
targets = node.targets
try:
assert len(targets) == 1
except AssertionError:
            raise DiffEquationError('Multiple assignment is not supported.')
left = ast2code(ast.fix_missing_locations(targets[0]))
key = targets[0].slice.value.s
value = targets[0].value.id
if node.value.__class__.__name__ == 'BinOp':
r_left = ast2code(ast.fix_missing_locations(node.value.left))
r_right = ast2code(ast.fix_missing_locations(node.value.right))
op = ast2code(ast.fix_missing_locations(node.value.op))
if op not in ['+', '-']:
# raise ValueError(f'Unsupported operation "{op}" for {left}.')
return node
self.left = f'{value}[{self.var2idx[key]}]'
if r_left == left:
if op == '+':
self.right = r_right
if op == '-':
self.right = f'- {r_right}'
elif r_left == '-' + left:
if op == '+':
self.right = f"2 * {left} + {r_right}"
if op == '-':
self.right = f"2 * {left} - {r_right}"
elif r_right == left:
if op == '+':
self.right = r_left
if op == '-':
self.right = f"{r_left} + 2 * {left}"
elif r_right == '-' + left:
if op == '+':
self.right = f"{r_left} + 2 * {left}"
if op == '-':
self.right = r_left
else:
return node
return node

def visit_AugAssign(self, node):
op = ast2code(ast.fix_missing_locations(node.op))
expr = ast2code(ast.fix_missing_locations(node.value))
if op not in ['+', '-']:
# left = ast2code(ast.fix_missing_locations(node.target))
# raise ValueError(f'Unsupported operation "{op}" for {left}.')
return node

key = node.target.slice.value.s
value = node.target.value.id

self.left = f'{value}[{self.var2idx[key]}]'
if op == '+':
self.right = expr
if op == '-':
self.right = f'- {expr}'

return node


def find_atomic_op(code_line, var2idx):
tree = ast.parse(code_line.strip())
formatter = FindAtomicOp(var2idx)
formatter.visit(tree)
return formatter
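
``find_atomic_op`` recognizes an in-place update of one state item, written either as an augmented assignment or as ``x = x + expr``, and rewrites the string key into its numeric index. A sketch of both forms (the visitor relies on the Python <= 3.8 AST layout, where a string subscript appears as ``slice.value.s``):

res = find_atomic_op("ST['V'] += g * (E - V)", {'V': 0, 'm': 1})
print(res.left)    # ST[0]
print(res.right)   # g * (E - V)

res = find_atomic_op("ST['m'] = ST['m'] + dm", {'V': 0, 'm': 1})
print(res.left)    # ST[1]
print(res.right)   # dm
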


class CodeLineFormatter(ast.NodeTransformer):
def __init__(self):
self.lefts = []
self.rights = []
self.lines = []
self.scope = dict()

def visit_Assign(self, node, level=0):
targets = node.targets
        if len(targets) != 1:
            raise DiffEquationError('Multiple assignment is not supported.')
target = ast2code(ast.fix_missing_locations(targets[0]))
expr = ast2code(ast.fix_missing_locations(node.value))
prefix = ' ' * level
self.lefts.append(target)
self.rights.append(expr)
self.lines.append(f'{prefix}{target} = {expr}')
return node

def visit_AugAssign(self, node, level=0):
target = ast2code(ast.fix_missing_locations(node.target))
op = ast2code(ast.fix_missing_locations(node.op))
expr = ast2code(ast.fix_missing_locations(node.value))
prefix = ' ' * level
self.lefts.append(target)
self.rights.append(f"{target} {op} {expr}")
self.lines.append(f"{prefix}{target} {op}= {expr}")
return node

    def visit_AnnAssign(self, node):
        raise NotImplementedError('Assignments with type annotations are not supported.')

def visit_node_not_assign(self, node, level=0):
prefix = ' ' * level
expr = ast2code(ast.fix_missing_locations(node))
self.lines.append(f'{prefix}{expr}')

def visit_Assert(self, node, level=0):
self.visit_node_not_assign(node, level)

def visit_Expr(self, node, level=0):
self.visit_node_not_assign(node, level)

def visit_Expression(self, node, level=0):
self.visit_node_not_assign(node, level)

def visit_content_in_condition_control(self, node, level):
if isinstance(node, ast.Expr):
self.visit_Expr(node, level)
elif isinstance(node, ast.Assert):
self.visit_Assert(node, level)
elif isinstance(node, ast.Assign):
self.visit_Assign(node, level)
elif isinstance(node, ast.AugAssign):
self.visit_AugAssign(node, level)
elif isinstance(node, ast.If):
self.visit_If(node, level)
elif isinstance(node, ast.For):
self.visit_For(node, level)
elif isinstance(node, ast.While):
self.visit_While(node, level)
else:
code = ast2code(ast.fix_missing_locations(node))
raise CodeError(f'BrainPy does not support {type(node)}.\n\n{code}')

def visit_If(self, node, level=0):
# If condition
prefix = ' ' * level
compare = ast2code(ast.fix_missing_locations(node.test))
self.lines.append(f'{prefix}if {compare}:')
# body
for expr in node.body:
self.visit_content_in_condition_control(expr, level + 1)

# elif
while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
node = node.orelse[0]
compare = ast2code(ast.fix_missing_locations(node.test))
self.lines.append(f'{prefix}elif {compare}:')
for expr in node.body:
self.visit_content_in_condition_control(expr, level + 1)

# else:
if len(node.orelse) > 0:
self.lines.append(f'{prefix}else:')
for expr in node.orelse:
self.visit_content_in_condition_control(expr, level + 1)

def visit_For(self, node, level=0):
prefix = ' ' * level
# target
target = ast2code(ast.fix_missing_locations(node.target))
        # iter
        iter_code = ast2code(ast.fix_missing_locations(node.iter))
        self.lefts.append(target)
        self.rights.append(iter_code)
        self.lines.append(prefix + f'for {target} in {iter_code}:')
# body
for expr in node.body:
self.visit_content_in_condition_control(expr, level + 1)
# else
if len(node.orelse) > 0:
self.lines.append(prefix + 'else:')
for expr in node.orelse:
self.visit_content_in_condition_control(expr, level + 1)

def visit_While(self, node, level=0):
prefix = ' ' * level
# test
test = ast2code(ast.fix_missing_locations(node.test))
self.rights.append(test)
self.lines.append(prefix + f'while {test}:')
# body
for expr in node.body:
self.visit_content_in_condition_control(expr, level + 1)
# else
if len(node.orelse) > 0:
self.lines.append(prefix + 'else:')
for expr in node.orelse:
self.visit_content_in_condition_control(expr, level + 1)

    def visit_Try(self, node):
        raise CodeError('"try" blocks are not supported.')

    def visit_With(self, node):
        raise CodeError('"with" blocks are not supported.')

    def visit_Raise(self, node):
        raise CodeError('"raise" statements are not supported.')

    def visit_Delete(self, node):
        raise CodeError('"del" statements are not supported.')


def format_code(code_string):
"""Get code lines from the string.

Parameters
----------
code_string

Returns
-------
code_lines : list
"""

tree = ast.parse(code_string.strip())
formatter = CodeLineFormatter()
formatter.visit(tree)
return formatter
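
A sketch of what the formatter collects: each statement becomes one entry in ``formatter.lines`` (nested bodies get a per-level prefix), and assignment targets accumulate in ``formatter.lefts``. Literal formatting may differ slightly after the AST round-trip:

src = """
a = ST['V'] + inp
if a > V_th:
    spike = 1.
else:
    spike = 0.
"""
fmt = format_code(src)
print(fmt.lines)   # ["a = ST['V'] + inp", 'if a > V_th:', ' spike = 1.0', 'else:', ' spike = 0.0']
print(fmt.lefts)   # ['a', 'spike', 'spike']
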


class LineFormatterForTrajectory(CodeLineFormatter):
def __init__(self, fixed_vars):
super(LineFormatterForTrajectory, self).__init__()
self.fixed_vars = fixed_vars

def visit_Assign(self, node, level=0):
targets = node.targets
        if len(targets) != 1:
            raise DiffEquationError(f'Multiple assignment is not supported. \n'
                                    f'Error in code line: \n\n'
                                    f'{ast2code(ast.fix_missing_locations(node))}')
prefix = ' ' * level
target = targets[0]
append_lines = []

if isinstance(target, ast.Subscript):
if target.value.id == 'ST' and target.slice.value.s in self.fixed_vars:
left = ast2code(ast.fix_missing_locations(target))
self.lefts.append(left)
key = target.slice.value.s
self.lines.append(f'{prefix}{left} = _fixed_{key}')
self.scope[f'_fixed_{key}'] = self.fixed_vars[key]
return node

elif hasattr(target, 'elts'):
if len(target.elts) == 1:
elt = target.elts[0]
if isinstance(elt, ast.Subscript):
if elt.value.id == 'ST' and elt.slice.value.s in self.fixed_vars:
left = ast2code(ast.fix_missing_locations(elt))
self.lefts.append(left)
key = elt.slice.value.s
self.lines.append(f'{prefix}{left} = _fixed_{key}')
self.scope[f'_fixed_{key}'] = self.fixed_vars[key]
return node
left = ast2code(ast.fix_missing_locations(elt))
expr = ast2code(ast.fix_missing_locations(node.value))
self.lefts.append(left)
self.rights.append(expr)
self.lines.append(f'{prefix}{left} = {expr}')
return node
else:
for elt in target.elts:
if isinstance(elt, ast.Subscript):
if elt.value.id == 'ST' and elt.slice.value.s in self.fixed_vars:
left = ast2code(ast.fix_missing_locations(elt))
key = elt.slice.value.s
line = f'{prefix}{left} = _fixed_{key}'
self.scope[f'_fixed_{key}'] = self.fixed_vars[key]
append_lines.append(line)
left = ast2code(ast.fix_missing_locations(target))
expr = ast2code(ast.fix_missing_locations(node.value))
self.lefts.append(target)
self.rights.append(expr)
self.lines.append(f'{prefix}{left} = {expr}')
self.lines.extend(append_lines)
return node

left = ast2code(ast.fix_missing_locations(target))
expr = ast2code(ast.fix_missing_locations(node.value))
self.lefts.append(left)
self.rights.append(expr)
self.lines.append(f'{prefix}{left} = {expr}')
return node

def visit_AugAssign(self, node, level=0):
prefix = ' ' * level
if isinstance(node.target, ast.Subscript):
if node.target.value.id == 'ST' and node.target.slice.value.s in self.fixed_vars:
left = ast2code(ast.fix_missing_locations(node.target))
self.lefts.append(left)
key = node.target.slice.value.s
self.lines.append(f'{prefix}{left} = _fixed_{key}')
self.scope[f'_fixed_{key}'] = self.fixed_vars[key]
return node

op = ast2code(ast.fix_missing_locations(node.op))
left = ast2code(ast.fix_missing_locations(node.target))
expr = ast2code(ast.fix_missing_locations(node.value))
self.lefts.append(left)
self.rights.append(f"{left} {op} {expr}")
self.lines.append(f"{prefix}{left} {op}= {expr}")
return node


def format_code_for_trajectory(code_string, fixed_vars):
"""Get _code lines from the string.

Parameters
----------
code_string

Returns
-------
code_lines : list
"""

tree = ast.parse(code_string.strip())
formatter = LineFormatterForTrajectory(fixed_vars)
formatter.visit(tree)
return formatter


######################################
# String tools
######################################


def indent(text, num_tabs=1, spaces_per_tab=4, tab=None):
if tab is None:
tab = ' ' * spaces_per_tab
@@ -609,22 +112,31 @@ def word_replace(expr, substitutions):
banana*_b+c5+8+func(A)
"""
for var, replace_var in substitutions.items():
expr = re.sub(r'\b' + var + r'\b', str(replace_var), expr)
# expr = re.sub(r'\b' + var + r'\b', str(replace_var), expr)
expr = re.sub(r'\b(?<!\.)' + var + r'\b(?!\.)', str(replace_var), expr)
return expr
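
The lookaround guards are the point of the new pattern: a plain ``\b``-delimited substitution would also rewrite attribute accesses such as ``obj.V``, while ``(?<!\.)`` and ``(?!\.)`` skip any name that touches a dot. For example:

expr = 'V + obj.V + V.mean()'
print(word_replace(expr, {'V': 'ST[0]'}))   # ST[0] + obj.V + V.mean()
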


def func_call(args):
if isinstance(args, set):
args = sorted(list(args))
else:
assert isinstance(args, (tuple, list))
func_args = []
for i in range(0, len(args), 5):
for arg in args[i: i + 5]:
func_args.append(f'{arg},')
func_args.append('\n ')
return ' '.join(func_args).strip()
# return ', '.join(args).strip()
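
``func_call`` renders an argument list for generated source, wrapping after every five names so long call sites stay readable (sets are sorted first; the exact spacing depends on the join separators). Roughly:

print(func_call(['V', 'm', 'h', 'n', 'ge', 'gi', 'inp']))
# V, m, h, n, ge,
#   gi, inp,
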
######################################
# Other tools
######################################


def is_lambda_function(func):
"""Check whether the function is a ``lambda`` function. Comes from
https://stackoverflow.com/questions/23852423/how-to-check-that-variable-is-a-lambda-function

Parameters
----------
func : callable function
The function.

Returns
-------
bool
True of False.
"""
return isinstance(func, LambdaType) and func.__name__ == "<lambda>"
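
For example, a ``def`` function never matches, while a lambda still matches after being bound to a name, because binding does not change ``__name__``:

f = lambda x: x ** 2

def g(x):
    return x ** 2

assert is_lambda_function(f)
assert not is_lambda_function(g)
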


def get_func_source(func):
@@ -636,3 +148,43 @@ def get_func_source(func):
except ValueError:
pass
return code


def get_main_code(func):
"""Get the main function _code string.

For lambda function, return the

Parameters
----------
func : callable, Optional, int, float

Returns
-------

"""
if func is None:
return ''
elif callable(func):
if is_lambda_function(func):
func_code = get_func_source(func)
splits = func_code.split(':')
if len(splits) != 2:
                raise ValueError(f'Cannot parse function: \n{func_code}')
return f'return {splits[1]}'

else:
func_codes = inspect.getsourcelines(func)[0]
idx = 0
for i, line in enumerate(func_codes):
idx += 1
line = line.replace(' ', '')
if '):' in line:
break
            else:
                code = ''.join(func_codes)
                raise ValueError(f'Cannot parse function: \n{code}')
return ''.join(func_codes[idx:])
else:
raise ValueError(f'Unknown function type: {type(func)}.')


+ 2
- 1
brainpy/tools/dicts.py

@@ -2,13 +2,14 @@

import copy


__all__ = [
'DictPlus'
]


class DictPlus(dict):
"""Python dictionaries with advanced dot notation access.
"""Python dictionaries with tutorials_advanced dot notation access.

For example:



+ 0
- 125
brainpy/tools/functions.py

@@ -1,125 +0,0 @@
# -*- coding: utf-8 -*-

import functools
import inspect
import math
import types

import numba as nb
from numba import cuda
from numba.core.dispatcher import Dispatcher

from .codes import deindent
from .codes import get_func_source
from .. import backend
from .. import profile

__all__ = [
'get_cuda_size',
'jit',
'func_copy',
'numba_func',
'get_func_name',
]


def get_cuda_size(num):
if num < profile.get_num_thread_gpu():
num_block, num_thread = 1, num
else:
num_thread = profile.get_num_thread_gpu()
num_block = math.ceil(num / num_thread)
return num_block, num_thread


def get_func_name(func, replace=False):
func_name = func.__name__
if replace:
func_name = func_name.replace('_brainpy_delayed_', '')
return func_name


def jit(func=None):
"""JIT user defined functions.

    Parameters
    ----------
    func : callable
        The function to be jitted.

    Returns
    -------
    jit_func : callable
        The jitted function.
    """
if not isinstance(func, Dispatcher):
if not callable(func):
raise ValueError(f'"func" must be a callable function, but got "{type(func)}".')
op = profile.get_numba_profile()
func = nb.jit(func, **op)
return func


def func_copy(f):
"""Make a deepcopy of a python function.

This method is adopted from http://stackoverflow.com/a/6528148/190597 (Glenn Maynard).

Parameters
----------
f : callable
Function to copy.

Returns
-------
g : callable
Copied function.
"""
g = types.FunctionType(code=f.__code__,
globals=f.__globals__,
name=f.__name__,
argdefs=f.__defaults__,
closure=f.__closure__)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = f.__kwdefaults__
return g


def numba_func(func, params={}):
if isinstance(func, Dispatcher):
return func
if backend.func_in_numpy_or_math(func):
return func

vars = inspect.getclosurevars(func)
code_scope = dict(vars.nonlocals)
code_scope.update(vars.globals)

modified = False
# check scope variables
for k, v in code_scope.items():
# function
if callable(v):
if (not backend.func_in_numpy_or_math(v)) and (not isinstance(v, Dispatcher)):
# if v != np.func_by_name(v.__name__)
code_scope[k] = numba_func(v, params)
modified = True
# check scope changed parameters
for p, v in params.items():
if p in code_scope:
code_scope[p] = v
modified = True

if modified:
func_code = deindent(get_func_source(func))
exec(compile(func_code, '', "exec"), code_scope)
func = code_scope[func.__name__]
if profile.run_on_cpu():
return jit(func)
else:
return cuda.jit(device=True)(func)
else:
if profile.run_on_cpu():
return jit(func)
else:
return cuda.jit(device=True)(func)

+ 7
- 7
brainpy/visualization/figures.py

@@ -9,18 +9,18 @@ __all__ = [
]


def get_figure(n_row, n_col, len_row=3, len_col=6):
def get_figure(row_num, col_num, row_len=3, col_len=6):
"""Get the constrained_layout figure.

Parameters
----------
n_row : int
row_num : int
The row number of the figure.
n_col : int
col_num : int
The column number of the figure.
len_row : int, float
row_len : int, float
The length of each row.
len_col : int, float
col_len : int, float
The length of each column.

Returns
@@ -28,6 +28,6 @@ def get_figure(n_row, n_col, len_row=3, len_col=6):
fig_and_gs : tuple
Figure and GridSpec.
"""
fig = plt.figure(figsize=(n_col * len_col, n_row * len_row), constrained_layout=True)
gs = GridSpec(n_row, n_col, figure=fig)
fig = plt.figure(figsize=(col_num * col_len, row_num * row_len), constrained_layout=True)
gs = GridSpec(row_num, col_num, figure=fig)
return fig, gs
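
With the renamed parameters, a typical call creates the constrained-layout figure and pulls axes out of the returned ``GridSpec``, e.g.:

fig, gs = get_figure(row_num=2, col_num=1, row_len=3, col_len=6)
ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[1, 0])
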

+ 4
- 4
brainpy/visualization/plots.py

@@ -5,8 +5,8 @@ import numpy as np
from matplotlib import animation
from matplotlib.gridspec import GridSpec

from .. import profile
from ..errors import ModelUseError
from brainpy import backend
from brainpy.errors import ModelUseError

__all__ = [
'line_plot',
@@ -234,7 +234,7 @@ def animate_2D(values,
figure : plt.figure
The created figure instance.
"""
dt = profile.get_dt() if dt is None else dt
dt = backend.get_dt() if dt is None else dt
num_step, num_neuron = values.shape
height, width = net_size
val_min = values.min() if val_min is None else val_min
@@ -336,7 +336,7 @@ def animate_1D(dynamical_vars,
"""

# check dt
dt = profile.get_dt() if dt is None else dt
dt = backend.get_dt() if dt is None else dt

# check figure
fig = plt.figure(figsize=(figsize or (6, 6)), constrained_layout=True)


+ 7
- 7
develop/benchmark/COBA/COBA.py

@@ -97,16 +97,16 @@ def run_brianpy(num_neu, duration, device='cpu'):
ST=bp.types.SynState([]),
mode='vector')

group = bp.NeuGroup(neuron, geometry=num_exc + num_inh)
group = bp.NeuGroup(neuron, size=num_exc + num_inh)
group.ST['V'] = np.random.randn(num_exc + num_inh) * 5. - 55.

exc_conn = bp.SynConn(exc_syn, pre_group=group[:num_exc],
post_group=group,
conn=bp.connect.FixedProb(prob=0.02))
exc_conn = bp.TwoEndConn(exc_syn, pre=group[:num_exc],
post=group,
conn=bp.connect.FixedProb(prob=0.02))

inh_conn = bp.SynConn(inh_syn, pre_group=group[num_exc:],
post_group=group,
conn=bp.connect.FixedProb(prob=0.02))
inh_conn = bp.TwoEndConn(inh_syn, pre=group[num_exc:],
post=group,
conn=bp.connect.FixedProb(prob=0.02))

net = bp.Network(group, exc_conn, inh_conn)



+ 86
- 93
develop/benchmark/COBA/COBA_brainpy.py

@@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
import time

import numpy as np

import time
import numpy as np
import brainpy as bp

dt = 0.05
bp.profile.set(jit=True, dt=dt)
bp.backend.set('numba', dt=dt)

# Parameters
num_exc = 3200
@@ -24,99 +24,92 @@ we = 0.6 # excitatory synaptic weight (voltage)
wi = 6.7 # inhibitory synaptic weight
ref = 5.0

neu_ST = bp.types.NeuState(
{'sp_t': -1e7,
'V': 0.,
'spike': 0.,
'ge': 0.,
'gi': 0.}
)


@bp.integrate
def int_ge(ge, t):
return - ge / taue


@bp.integrate
def int_gi(gi, t):
return - gi / taui


@bp.integrate
def int_V(V, t, ge, gi):
return (ge * (Erev_exc - V) + gi * (Erev_inh - V) + (El - V) + I) / taum


def neu_update(ST, _t):
ST['ge'] = int_ge(ST['ge'], _t)
ST['gi'] = int_gi(ST['gi'], _t)

ST['spike'] = 0.
if (_t - ST['sp_t']) > ref:
V = int_V(ST['V'], _t, ST['ge'], ST['gi'])
ST['spike'] = 0.
if V >= Vt:
ST['V'] = Vr
ST['spike'] = 1.
ST['sp_t'] = _t
else:
ST['V'] = V


neuron = bp.NeuType(name='COBA',
ST=neu_ST,
steps=neu_update,
mode='scalar')


def update1(pre, post, pre2post):
for pre_id in range(len(pre2post)):
if pre['spike'][pre_id] > 0.:
post_ids = pre2post[pre_id]
for i in post_ids:
post['ge'][i] += we


exc_syn = bp.SynType('exc_syn',
steps=update1,
ST=bp.types.SynState([]),
mode='vector')


def update2(pre, post, pre2post):
for pre_id in range(len(pre2post)):
if pre['spike'][pre_id] > 0.:
post_ids = pre2post[pre_id]
for i in post_ids:
post['gi'][i] += wi


inh_syn = bp.SynType('inh_syn',
steps=update2,
ST=bp.types.SynState([]),
mode='vector')


group = bp.NeuGroup(neuron,
geometry=num_exc + num_inh,
monitors=['spike'])
group.ST['V'] = np.random.randn(num_exc + num_inh) * 5. - 55.

exc_conn = bp.SynConn(exc_syn,
pre_group=group[:num_exc],
post_group=group,
conn=bp.connect.FixedProb(prob=0.02))

inh_conn = bp.SynConn(inh_syn,
pre_group=group[num_exc:],
post_group=group,
conn=bp.connect.FixedProb(prob=0.02))

net = bp.Network(group, exc_conn, inh_conn)
class LIF(bp.NeuGroup):
target_backend = ['numpy', 'numba']

def __init__(self, size, **kwargs):
# variables
self.V = bp.backend.zeros(size)
self.spike = bp.backend.zeros(size)
self.ge = bp.backend.zeros(size)
self.gi = bp.backend.zeros(size)
self.input = bp.backend.zeros(size)
self.t_last_spike = bp.backend.ones(size) * -1e7

super(LIF, self).__init__(size=size, **kwargs)

@staticmethod
@bp.odeint(method='euler')
def int_g(ge, gi, t):
dge = - ge / taue
dgi = - gi / taui
return dge, dgi

@staticmethod
@bp.odeint(method='euler')
def int_V(V, t, ge, gi):
dV = (ge * (Erev_exc - V) + gi * (Erev_inh - V) + El - V + I) / taum
return dV

def update(self, _t):
self.ge, self.gi = self.int_g(self.ge, self.gi, _t)
for i in range(self.size[0]):
self.spike[i] = 0.
if (_t - self.t_last_spike[i]) > ref:
V = self.int_V(self.V[i], _t, self.ge[i], self.gi[i])
if V >= Vt:
self.V[i] = Vr
self.spike[i] = 1.
self.t_last_spike[i] = _t
else:
self.V[i] = V
self.input[i] = I


class ExcSyn(bp.TwoEndConn):
target_backend = ['numpy', 'numba']

def __init__(self, pre, post, conn, **kwargs):
self.conn = conn(pre.size, post.size)
self.pre2post = self.conn.requires('pre2post')
super(ExcSyn, self).__init__(pre=pre, post=post, **kwargs)

def update(self, _t):
for pre_id, spike in enumerate(self.pre.spike):
if spike > 0:
for post_i in self.pre2post[pre_id]:
self.post.ge[post_i] += we


class InhSyn(bp.TwoEndConn):
target_backend = ['numpy', 'numba']

def __init__(self, pre, post, conn, **kwargs):
self.conn = conn(pre.size, post.size)
self.pre2post = self.conn.requires('pre2post')
super(InhSyn, self).__init__(pre=pre, post=post, **kwargs)

def update(self, _t):
for pre_id, spike in enumerate(self.pre.spike):
if spike > 0:
for post_i in self.pre2post[pre_id]:
self.post.gi[post_i] += wi


E_group = LIF(num_exc, monitors=['spike'])
E_group.V = np.random.randn(num_exc) * 5. - 55.
I_group = LIF(num_inh, monitors=['spike'])
I_group.V = np.random.randn(num_inh) * 5. - 55.
E2E = ExcSyn(pre=E_group, post=E_group, conn=bp.connect.FixedProb(0.02))
E2I = ExcSyn(pre=E_group, post=I_group, conn=bp.connect.FixedProb(0.02))
I2E = InhSyn(pre=I_group, post=E_group, conn=bp.connect.FixedProb(0.02))
I2I = InhSyn(pre=I_group, post=I_group, conn=bp.connect.FixedProb(0.02))

net = bp.Network(E_group, I_group, E2E, E2I, I2E, I2I)
t0 = time.time()

net.run(5000., report=True)
print('Used time {} s.'.format(time.time() - t0))

bp.visualize.raster_plot(net.ts, group.mon.spike, show=True)
bp.visualize.raster_plot(net.ts, E_group.mon.spike, show=True)

+ 5
- 5
develop/benchmark/COBAHH/COBAHH_brainpy.py

@@ -117,16 +117,16 @@ exc_syn = bp.SynType('exc_syn', steps=exc_update, ST=bp.types.SynState())

inh_syn = bp.SynType('inh_syn', steps=inh_update, ST=bp.types.SynState())

group = bp.NeuGroup(neuron, geometry=num_exc + num_inh, monitors=['sp'])
group = bp.NeuGroup(neuron, size=num_exc + num_inh, monitors=['sp'])
group.ST['V'] = El + (np.random.randn(num_exc + num_inh) * 5 - 5)
group.ST['ge'] = (np.random.randn(num_exc + num_inh) * 1.5 + 4) * 10.
group.ST['gi'] = (np.random.randn(num_exc + num_inh) * 12 + 20) * 10.

exc_conn = bp.SynConn(exc_syn, pre_group=group[:num_exc], post_group=group,
conn=bp.connect.FixedProb(prob=0.02))
exc_conn = bp.TwoEndConn(exc_syn, pre=group[:num_exc], post=group,
conn=bp.connect.FixedProb(prob=0.02))

inh_conn = bp.SynConn(inh_syn, pre_group=group[num_exc:], post_group=group,
conn=bp.connect.FixedProb(prob=0.02))
inh_conn = bp.TwoEndConn(inh_syn, pre=group[num_exc:], post=group,
conn=bp.connect.FixedProb(prob=0.02))

net = bp.Network(group, exc_conn, inh_conn)
t0 = time.time()


+ 9
- 9
develop/benchmark/CUBA/CUBA_brainpy.py

@@ -87,19 +87,19 @@ def update2(pre, post, pre2post):
inh_syn = bp.SynType('inh_syn', steps=update2, ST=bp.types.SynState())

group = bp.NeuGroup(neuron,
geometry=num_exc + num_inh,
size=num_exc + num_inh,
monitors=['sp'])
group.ST['V'] = Vr + np.random.rand(num_exc + num_inh) * (Vt - Vr)

exc_conn = bp.SynConn(exc_syn,
pre_group=group[:num_exc],
post_group=group,
conn=bp.connect.FixedProb(prob=0.02))
exc_conn = bp.TwoEndConn(exc_syn,
pre=group[:num_exc],
post=group,
conn=bp.connect.FixedProb(prob=0.02))

inh_conn = bp.SynConn(inh_syn,
pre_group=group[num_exc:],
post_group=group,
conn=bp.connect.FixedProb(prob=0.02))
inh_conn = bp.TwoEndConn(inh_syn,
pre=group[num_exc:],
post=group,
conn=bp.connect.FixedProb(prob=0.02))

net = bp.Network(group, exc_conn, inh_conn, mode='repeat')
t0 = time.time()


+ 23
- 20
develop/benchmark/scaling_test.py

@@ -8,6 +8,7 @@ Test the network scaling ability.
import time
import brainpy as bp
import numpy as np
import math


def define_hh(E_Na=50., g_Na=120., E_K=-77., g_K=36., E_Leak=-54.387,
@@ -25,20 +26,20 @@ def define_hh(E_Na=50., g_Na=120., E_K=-77., g_K=36., E_Leak=-54.387,

@bp.integrate
def int_m(m, t, V):
alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
beta = 4.0 * np.exp(-(V + 65) / 18)
alpha = 0.1 * (V + 40) / (1 - math.exp(-(V + 40) / 10))
beta = 4.0 * math.exp(-(V + 65) / 18)
return alpha * (1 - m) - beta * m

@bp.integrate
def int_h(h, t, V):
alpha = 0.07 * np.exp(-(V + 65) / 20.)
beta = 1 / (1 + np.exp(-(V + 35) / 10))
alpha = 0.07 * math.exp(-(V + 65) / 20.)
beta = 1 / (1 + math.exp(-(V + 35) / 10))
return alpha * (1 - h) - beta * h

@bp.integrate
def int_n(n, t, V):
alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
beta = 0.125 * np.exp(-(V + 65) / 80)
alpha = 0.01 * (V + 55) / (1 - math.exp(-(V + 55) / 10))
beta = 0.125 * math.exp(-(V + 65) / 80)
return alpha * (1 - n) - beta * n

@bp.integrate
@@ -77,7 +78,7 @@ def hh_compare_cpu_and_multi_cpu(num=1000, vector=True):

HH = define_hh()
HH.mode = 'vector' if vector else 'scalar'
neu = bp.NeuGroup(HH, geometry=num)
neu = bp.NeuGroup(HH, size=num)

t0 = time.time()
neu.run(duration=1000., report=True)
@@ -86,7 +87,7 @@ def hh_compare_cpu_and_multi_cpu(num=1000, vector=True):

print(f'HH, vector_based={vector}, device=multi-cpu', end=', ')
bp.profile.set(jit=True, device='multi-cpu')
neu = bp.NeuGroup(HH, geometry=num)
neu = bp.NeuGroup(HH, size=num)
t0 = time.time()
neu.run(duration=1000., report=True)
t_multi_cpu = time.time() - t0
@@ -98,27 +99,27 @@ def hh_compare_cpu_and_multi_cpu(num=1000, vector=True):

def hh_compare_cpu_and_gpu(num=1000):
print(f'HH, device=cpu', end=', ')
bp.profile.set(jit=True, device='cpu', show_code=True)
bp.profile.set(jit=True, device='cpu', show_code=False)

HH = define_hh()
HH.mode = 'scalar'
neu = bp.NeuGroup(HH, geometry=num)
t0 = time.time()
neu.run(duration=1000., report=True)
t_cpu = time.time() - t0
print('used {:.3f} ms'.format(t_cpu))
# neu = bp.NeuGroup(HH, geometry=num)
#
# t0 = time.time()
# neu.run(duration=1000., report=True)
# t_cpu = time.time() - t0
# print('used {:.3f} ms'.format(t_cpu))

print(f'HH, device=gpu', end=', ')
bp.profile.set(jit=True, device='gpu')
neu = bp.NeuGroup(HH, geometry=num)
neu = bp.NeuGroup(HH, size=num)
t0 = time.time()
neu.run(duration=1000., report=True)
t_multi_cpu = time.time() - t0
print('used {:.3f} ms'.format(t_multi_cpu))

print(f"HH model with multi-cpu speeds up {t_cpu / t_multi_cpu}")
print()
# print(f"HH model with multi-cpu speeds up {t_cpu / t_multi_cpu}")
# print()


if __name__ == '__main__':
@@ -128,9 +129,11 @@ if __name__ == '__main__':
# hh_compare_cpu_and_multi_cpu(int(1e5))
# hh_compare_cpu_and_multi_cpu(int(1e6))

# hh_compare_cpu_and_gpu(int(1e2))
# hh_compare_cpu_and_gpu(int(1e3))
# hh_compare_cpu_and_gpu(int(1e4))
# hh_compare_cpu_and_gpu(int(1e5))
# hh_compare_cpu_and_gpu(int(1e6))
hh_compare_cpu_and_gpu(int(1e5))
hh_compare_cpu_and_gpu(int(1e6))
# hh_compare_cpu_and_gpu(int(1e7))




+ 1
- 2
develop/conda-recipe/meta.yaml

@@ -1,6 +1,6 @@
package:
name: brainpy-simulator
version: "0.3.5"
version: "1.0.0-alpha"

source:
path: ../../
@@ -18,7 +18,6 @@ requirements:
- python
- numpy>=1.13
- sympy>=1.2
- scipy>=1.2.0
- numba>=0.50
- matplotlib>=3.0
- setuptools>=40.0.0


+ 1
- 1
docs/Makefile

@@ -5,7 +5,7 @@
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SPHINXPROJ = npbrain
SPHINXPROJ = brainpy
SOURCEDIR = .
BUILDDIR = _build



+ 0
- 291
docs/advanced/HH_model_in_ANNarchy.ipynb
File diff suppressed because it is too large


+ 0
- 0
docs/advanced/Limitations.rst


+ 0
- 393
docs/advanced/debugging.ipynb

@@ -1,393 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Debugging"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even if you write clear and readable code, even if you fully understand your codes, env if you are very familiar with your model, weird bugs will inevitably appear and you will need to debug them in some way. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fortunately, ``BrainPy`` supports debugging with [pdb](https://docs.python.org/3/library/pdb.html)\n",
"module or [breakpoint](https://docs.python.org/3/library/functions.html#breakpoint) (The latest version of \n",
"BrainPy removes the support of debugging in IDEs). That is to say, you do not need to resort to using bunch \n",
"of `print` statements to see what's happening in their code. On the contrary, you can work with \n",
"Python’s interactive source code debugger to see the state of any variable in your model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the variables you are interested in, you just need to add the ``pdb.set_trace()`` or ``breakpoint()`` after \n",
"the code line. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, let's take the HH neuron model as an example to illustrate how to debug your\n",
"model within BrainPy."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-29T09:15:31.217240Z",
"start_time": "2020-12-29T09:15:28.875338Z"
}
},
"outputs": [],
"source": [
"import brainpy as bp\n",
"import numpy as np\n",
"import pdb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to debug your model, we would like to recommond you to open the ``show_code=True``."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-29T09:15:31.233192Z",
"start_time": "2020-12-29T09:15:31.220227Z"
}
},
"outputs": [],
"source": [
"bp.profile.set(show_code=True, jit=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, the HH neuron model is defined as:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-29T09:15:31.332935Z",
"start_time": "2020-12-29T09:15:31.243166Z"
}
},
"outputs": [],
"source": [
"E_Na = 50.\n",
"E_K = -77.\n",
"E_leak = -54.387\n",
"C = 1.0\n",
"g_Na = 120.\n",
"g_K = 36.\n",
"g_leak = 0.03\n",
"V_th = 20.\n",
"noise = 1.\n",
"\n",
"ST = bp.types.NeuState(\n",
" {'V': -65., 'm': 0.05, 'h': 0.60,\n",
" 'n': 0.32, 'spike': 0., 'input': 0.}\n",
")\n",
"\n",
"@bp.integrate\n",
"def int_m(m, _t, V):\n",
" alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))\n",
" beta = 4.0 * np.exp(-(V + 65) / 18)\n",
" return alpha * (1 - m) - beta * m\n",
"\n",
"@bp.integrate\n",
"def int_h(h, _t, V):\n",
" alpha = 0.07 * np.exp(-(V + 65) / 20.)\n",
" beta = 1 / (1 + np.exp(-(V + 35) / 10))\n",
" return alpha * (1 - h) - beta * h\n",
"\n",
"@bp.integrate\n",
"def int_n(n, _t, V):\n",
" alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))\n",
" beta = 0.125 * np.exp(-(V + 65) / 80)\n",
" return alpha * (1 - n) - beta * n\n",
"\n",
"@bp.integrate\n",
"def int_V(V, _t, m, h, n, I_ext):\n",
" I_Na = (g_Na * np.power(m, 3.0) * h) * (V - E_Na)\n",
" I_K = (g_K * np.power(n, 4.0))* (V - E_K)\n",
" I_leak = g_leak * (V - E_leak)\n",
" dVdt = (- I_Na - I_K - I_leak + I_ext)/C\n",
" return dVdt, noise / C\n",
"\n",
"def update(ST, _t):\n",
" m = np.clip(int_m(ST['m'], _t, ST['V']), 0., 1.)\n",
" h = np.clip(int_h(ST['h'], _t, ST['V']), 0., 1.)\n",
" n = np.clip(int_n(ST['n'], _t, ST['V']), 0., 1.)\n",
" V = int_V(ST['V'], _t, m, h, n, ST['input'])\n",
" \n",
" pdb.set_trace()\n",
" \n",
" spike = np.logical_and(ST['V'] < V_th, V >= V_th)\n",
" ST['spike'] = spike\n",
" ST['V'] = V\n",
" ST['m'] = m\n",
" ST['h'] = h\n",
" ST['n'] = n\n",
" ST['input'] = 0.\n",
"\n",
"HH = bp.NeuType(ST=ST,\n",
" name='HH_neuron',\n",
" steps=update,\n",
" mode='vector')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we add ``pdb.set_trace()`` after the variables $m$, $h$, $n$ and $V$. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can create a neuron group, and try to run this neuron model:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-29T09:16:43.061109Z",
"start_time": "2020-12-29T09:15:31.334919Z"
},
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def NeuGroup0_input_step(ST, input_inp,):\n",
" # \"input\" step function of NeuGroup0\n",
" ST[5] += input_inp\n",
" \n",
"\n",
"\n",
"def NeuGroup0_monitor_step(ST, _i, mon_ST_spike,):\n",
" # \"monitor\" step function of NeuGroup0\n",
" mon_ST_spike[_i] = ST[4]\n",
" \n",
"\n",
"\n",
"def NeuGroup0_update(ST, _t,):\n",
" # \"update\" step function of NeuGroup0\n",
" _int_m_m = ST[1]\n",
" _int_m__t = _t\n",
" _int_m_V = ST[0]\n",
" _int_m_alpha = 0.1 * (_int_m_V + 40) / (1 - np.exp(-(_int_m_V + 40) / 10))\n",
" _int_m_beta = 4.0 * np.exp(-(_int_m_V + 65) / 18)\n",
" _dfm_dt = _int_m_alpha * (1 - _int_m_m) - _int_m_beta * _int_m_m\n",
" _int_m_m = 0.1*_dfm_dt + _int_m_m\n",
" _int_m_res = _int_m_m\n",
" m = np.clip(_int_m_res, 0.0, 1.0)\n",
" \n",
" _int_h_h = ST[2]\n",
" _int_h__t = _t\n",
" _int_h_V = ST[0]\n",
" _int_h_alpha = 0.07 * np.exp(-(_int_h_V + 65) / 20.0)\n",
" _int_h_beta = 1 / (1 + np.exp(-(_int_h_V + 35) / 10))\n",
" _dfh_dt = _int_h_alpha * (1 - _int_h_h) - _int_h_beta * _int_h_h\n",
" _int_h_h = 0.1*_dfh_dt + _int_h_h\n",
" _int_h_res = _int_h_h\n",
" h = np.clip(_int_h_res, 0.0, 1.0)\n",
" \n",
" _int_n_n = ST[3]\n",
" _int_n__t = _t\n",
" _int_n_V = ST[0]\n",
" _int_n_alpha = 0.01 * (_int_n_V + 55) / (1 - np.exp(-(_int_n_V + 55) / 10))\n",
" _int_n_beta = 0.125 * np.exp(-(_int_n_V + 65) / 80)\n",
" _dfn_dt = _int_n_alpha * (1 - _int_n_n) - _int_n_beta * _int_n_n\n",
" _int_n_n = 0.1*_dfn_dt + _int_n_n\n",
" _int_n_res = _int_n_n\n",
" n = np.clip(_int_n_res, 0.0, 1.0)\n",
" \n",
" _int_V_V = ST[0]\n",
" _int_V__t = _t\n",
" _int_V_m = m\n",
" _int_V_h = h\n",
" _int_V_n = n\n",
" _int_V_I_ext = ST[5]\n",
" _int_V_I_Na = g_Na * np.power(_int_V_m, 3.0) * _int_V_h * (_int_V_V - E_Na)\n",
" _int_V_I_K = g_K * np.power(_int_V_n, 4.0) * (_int_V_V - E_K)\n",
" _int_V_I_leak = g_leak * (_int_V_V - E_leak)\n",
" _int_V_dVdt = (-_int_V_I_Na - _int_V_I_K - _int_V_I_leak + _int_V_I_ext) / C\n",
" _dfV_dt = _int_V_dVdt\n",
" _V_dW = _normal_like(_int_V_V)\n",
" _dgV_dt = noise / C\n",
" _int_V_V = _int_V_V + 0.316227766016838*_V_dW*_dgV_dt + 0.1*_dfV_dt\n",
" _int_V_res = _int_V_V\n",
" V = _int_V_res\n",
" \n",
" pdb.set_trace()\n",
" \n",
" spike = np.logical_and(ST[0] < V_th, V >= V_th)\n",
" ST[4] = spike\n",
" ST[0] = V\n",
" ST[1] = m\n",
" ST[2] = h\n",
" ST[3] = n\n",
" ST[5] = 0.0\n",
" \n",
"\n",
"\n",
"def step_func(_t, _i, _dt):\n",
" NeuGroup0_runner.input_step(NeuGroup0.ST[\"_data\"], NeuGroup0_runner.input_inp,)\n",
" NeuGroup0_runner.update(NeuGroup0.ST[\"_data\"], _t,)\n",
" NeuGroup0_runner.monitor_step(NeuGroup0.ST[\"_data\"], _i, NeuGroup0.mon[\"spike\"],)\n",
"\n",
"> \u001b[1;32mc:\\users\\oujag\\codes\\projects\\brainpy\\docs\\advanced\u001b[0m(52)\u001b[0;36mupdate\u001b[1;34m()\u001b[0m\n",
"\n",
"ipdb> p m\n",
"array([0.05123855])\n",
"ipdb> p n\n",
"array([0.31995744])\n",
"ipdb> p h\n",
"array([0.59995445])\n",
"ipdb> n\n",
"> \u001b[1;32mc:\\users\\oujag\\codes\\projects\\brainpy\\docs\\advanced\u001b[0m(53)\u001b[0;36mupdate\u001b[1;34m()\u001b[0m\n",
"\n",
"ipdb> n\n",
"> \u001b[1;32mc:\\users\\oujag\\codes\\projects\\brainpy\\docs\\advanced\u001b[0m(54)\u001b[0;36mupdate\u001b[1;34m()\u001b[0m\n",
"\n",
"ipdb> n\n",
"> \u001b[1;32mc:\\users\\oujag\\codes\\projects\\brainpy\\docs\\advanced\u001b[0m(55)\u001b[0;36mupdate\u001b[1;34m()\u001b[0m\n",
"\n",
"ipdb> n\n",
"> \u001b[1;32mc:\\users\\oujag\\codes\\projects\\brainpy\\docs\\advanced\u001b[0m(56)\u001b[0;36mupdate\u001b[1;34m()\u001b[0m\n",
"\n",
"ipdb> n\n",
"> \u001b[1;32mc:\\users\\oujag\\codes\\projects\\brainpy\\docs\\advanced\u001b[0m(57)\u001b[0;36mupdate\u001b[1;34m()\u001b[0m\n",
"\n",
"ipdb> n\n",
"> \u001b[1;32mc:\\users\\oujag\\codes\\projects\\brainpy\\docs\\advanced\u001b[0m(58)\u001b[0;36mupdate\u001b[1;34m()\u001b[0m\n",
"\n",
"ipdb> p ST\n",
"array([[-6.41827214e+01],\n",
" [ 5.12385538e-02],\n",
" [ 5.99954448e-01],\n",
" [ 3.19957442e-01],\n",
" [ 0.00000000e+00],\n",
" [ 1.00000000e+01]])\n",
"ipdb> q\n"
]
},
{
"ename": "BdbQuit",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mBdbQuit\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-4-703915fa7809>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[0mgroup\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mNeuGroup\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mHH\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgeometry\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmonitors\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'spike'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mgroup\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1000.\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'input'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m10.\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;32m~\\codes\\projects\\BrainPy\\brainpy\\core\\base.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, duration, inputs, report, report_percent)\u001b[0m\n\u001b[0;32m 584\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 585\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mrun_idx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_length\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 586\u001b[1;33m \u001b[0mstep_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_t\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtimes\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mrun_idx\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_i\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mrun_idx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_dt\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdt\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 587\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 588\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mprofile\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun_on_gpu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m?\u001b[0m in \u001b[0;36mstep_func\u001b[1;34m(_t, _i, _dt)\u001b[0m\n",
"\u001b[1;32m?\u001b[0m in \u001b[0;36mupdate\u001b[1;34m(ST, _t)\u001b[0m\n",
"\u001b[1;32m?\u001b[0m in \u001b[0;36mupdate\u001b[1;34m(ST, _t)\u001b[0m\n",
"\u001b[1;32m~\\Miniconda3\\envs\\py38\\lib\\bdb.py\u001b[0m in \u001b[0;36mtrace_dispatch\u001b[1;34m(self, frame, event, arg)\u001b[0m\n\u001b[0;32m 86\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[1;31m# None\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 87\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'line'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 88\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdispatch_line\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 89\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'call'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 90\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdispatch_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0marg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Miniconda3\\envs\\py38\\lib\\bdb.py\u001b[0m in \u001b[0;36mdispatch_line\u001b[1;34m(self, frame)\u001b[0m\n\u001b[0;32m 111\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstop_here\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbreak_here\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0muser_line\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 113\u001b[1;33m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquitting\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mBdbQuit\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 114\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrace_dispatch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mBdbQuit\u001b[0m: "
]
}
],
"source": [
"group = bp.NeuGroup(HH, geometry=1, monitors=['spike'])\n",
"group.run(1000., inputs=('input', 10.))"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}

+ 0
- 232
docs/advanced/differential_equations.ipynb

@@ -1,232 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Differential Equations"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import brainpy as bp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In BrainPy, the difinition of differential equations is supportted by a powerfull decorator ``@bp.integrate``. Users should only explicitly write out the right hand of the differential equations, and BrainPy will automatically integerates your defined differential equations. \n",
"\n",
"BrainPy supports the numerical integration of ordinary differential equations (ODEs) and stochastic differential equations (SDEs). "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ODEs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"For an ordinary differential equation\n",
"\n",
"$$\n",
"\\frac{dx}{dt} = f(x, t)\n",
"$$\n",
"\n",
"the coding in BrainPy has a general form of:\n",
"\n",
"```python\n",
"\n",
"@bp.integrate\n",
"def func(x, t, other_arguments):\n",
" # ... do some computation\n",
" f = ...\n",
" return f\n",
"\n",
"x_t_plus = func(x_t, t, other_arguments)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SDEs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"For the stochastic differential equation:\n",
"\n",
"$$\n",
"\\frac{dx}{dt} = f(x, t) + g(x, t) dW\n",
"$$\n",
"\n",
"the coding in BrainPy can be conducted as:\n",
"\n",
"```python\n",
"\n",
"@bp.integrate\n",
"def func(x, t, other_arguments):\n",
" # ... do some computation\n",
" f = ...\n",
" g = ...\n",
" return f, g\n",
"\n",
"x_t_plus = func(x_t, t, other_arguments)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2020-11-08T13:31:02.326207Z",
"start_time": "2020-11-08T13:31:02.147805Z"
}
},
"source": [
"## Return intermediate values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"BrainPy also supports the user return the intermediate computed results. Let's take the differential equation of $V$ in Hodgkin–Huxley (HH) neuron model as an example. In HH model, the stochastic differential equation $V$ is expressed as:\n",
"\n",
"\\begin{align}\n",
"C_{m}{\\frac {d V}{dt}}&=-{\\bar {g}}_{K}n^{4}(V-V_{K}) - {\\bar {g}}_{Na}m^{3}h(V-V_{Na}) -{\\bar {g}}_{l}(V-V_{l}) + I^{ext} + I^{noise} * dW, \n",
"\\end{align}\n",
"\n",
"where \n",
"\n",
"- the potassium channel current is $I_{K} = {\\bar {g}}_{K}n^{4}(V-V_{K})$, \n",
"- the sodium channel current is $I_{Na} = {\\bar {g}}_{Na}m^{3}h(V-V_{Na})$, and \n",
"- the leaky current is $I_{L} = {\\bar {g}}_{l}(V-V_{l})$.\n",
"\n",
"The user may not only has the interest of the final value $V$, but also take care of the intermediate value $I_{Na}$, $I_K$ and $I_L$. In BrainPy, this kind of requirement can be coded as:\n",
"\n",
"```python\n",
"\n",
"@bp.integrate\n",
"def func(V, t, m, h, n, Iext):\n",
" INa = gNa * m ** 3 * h * (V - ENa)\n",
" IK = gK * n ** 4 * (V - EK)\n",
" IL = gLeak * (V - ELeak)\n",
" f = (- INa - IK - IL + Isyn) / C\n",
" g = noise / C\n",
" return (f, g), INa, IK, IL\n",
"\n",
"V_t_plus, INa, IK, IL = func(V_t, t, m, h, n, Iext)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Generally, return intermediate values in ODE function can be coded as:\n",
"\n",
"```python\n",
"\n",
"@bp.integrate\n",
"def func(x, t, other_arguments):\n",
" # ... do some computation\n",
" f = ...\n",
" return (f, ), some_values\n",
"```\n",
"\n",
"Return intermediate values in SDE function can be coded as:\n",
"\n",
"```python\n",
"\n",
"@bp.integrate\n",
"def func(x, t, other_arguments):\n",
" # ... do some computation\n",
" f = ...\n",
" g = ...\n",
" return (f, g), some_values\n",
"```"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}

+ 0
- 266
docs/advanced/gapjunction_lif_in_brian2.ipynb
File diff suppressed because it is too large


+ 4
- 4
docs/apis/analysis.rst

@@ -1,5 +1,5 @@
brainpy.analysis package
========================
brainpy.analysis
================

.. currentmodule:: brainpy.analysis
.. automodule:: brainpy.analysis
@@ -39,10 +39,10 @@ brainpy.analysis package
.. autoclass:: PhasePlane
:members: plot_fixed_point, plot_nullcline, plot_trajectory, plot_vector_field

.. autoclass:: PhasePlane1D
.. autoclass:: _PhasePlane1D
:members: plot_fixed_point, plot_nullcline, plot_trajectory, plot_vector_field

.. autoclass:: PhasePlane2D
.. autoclass:: _PhasePlane2D
:members: plot_fixed_point, plot_nullcline, plot_trajectory, plot_vector_field




+ 18
- 0
docs/apis/backend.rst

@@ -0,0 +1,18 @@
brainpy.backend
===============

.. currentmodule:: brainpy.backend
.. automodule:: brainpy.backend

.. autosummary::
:toctree: _autosummary

set
set_dt
get_dt
set_ops
get_backend
set_class_keywords
set_ops_from_module



+ 2
- 2
docs/apis/connectivity.rst

@@ -1,5 +1,5 @@
brainpy.connect package
============================
brainpy.connect
===============

.. currentmodule:: brainpy.connectivity
.. automodule:: brainpy.connectivity


+ 0
- 43
docs/apis/core.rst

@@ -1,43 +0,0 @@
brainpy.core package
====================

.. currentmodule:: brainpy.core
.. automodule:: brainpy.core

.. autosummary::
:toctree: _autosummary

ObjType
NeuType
SynType
Ensemble
NeuGroup
SynConn
Network
ParsUpdate
delayed

.. autoclass:: ObjType
:members:

.. autoclass:: NeuType
:members:

.. autoclass:: SynType
:members:

.. autoclass:: Ensemble
:members:

.. autoclass:: NeuGroup
:members:

.. autoclass:: SynConn
:members:

.. autoclass:: Network
:members: add, build, run

.. autoclass:: ParsUpdate
:members: get, keys, items


+ 5
- 5
docs/apis/errors.rst

@@ -1,5 +1,5 @@
brainpy.errors package
============================
brainpy.errors
==============

.. currentmodule:: brainpy.errors
.. automodule:: brainpy.errors
@@ -9,7 +9,7 @@ brainpy.errors package

ModelDefError
ModelUseError
TypeMismatchError
IntegratorError
DiffEquationError
DiffEqError
CodeError
AnalyzerError
PackageMissingError

+ 2
- 19
docs/apis/inputs.rst

@@ -1,5 +1,5 @@
brainpy.inputs package
============================
brainpy.inputs
==============

.. currentmodule:: brainpy.inputs
.. automodule:: brainpy.inputs
@@ -10,21 +10,4 @@ brainpy.inputs package
constant_current
spike_current
ramp_current
PoissonInput
SpikeTimeInput
FreqInput


.. autoclass:: PoissonInput
:toctree:
:members:

.. autoclass:: SpikeTimeInput
:toctree:
:members:

.. autoclass:: FreqInput
:toctree:
:members:



Some files were not shown because too many files changed in this diff
