Are you sure you want to delete this task? Once this task is deleted, it cannot be recovered.
Yanqi-Chen c1031c34bf | 10 months ago | |
---|---|---|
.. | ||
__init__.py | 1 year ago | |
base.py | 11 months ago | |
cfunction.py | 1 year ago | |
example.py | 1 year ago | |
generator.py | 1 year ago | |
neuron_kernel.py | 1 year ago | |
readme.md | 1 year ago | |
ss_neuron_kernel.py | 10 months ago |
auto_cuda
is an experimental package for automatically generating CUDA code from Python functions.
The final goal is that, after the user defines a new kind of spiking neuron in Python, the neuron can use the cupy
backend, whose code is generated by auto_cuda
.
For the moment, we have implemented generating CUDA code from the Python neuronal_charge
function.
Run the following Python code:
from spikingjelly.activation_based.auto_cuda.generator import analyse_graph, gen_forward_codes, gen_backward_codes
from spikingjelly.activation_based import surrogate
import torch
if __name__ == '__main__':
def lif_charge(x: torch.Tensor, v_last: torch.Tensor, tau: float, v_reset: float):
h = v_last + (x - (v_last - v_reset)) / tau
return h
input_nodes, inter_nodes, output_nodes, cmds = analyse_graph(lif_charge, requires_grad=(True, True, False, False))
forward_codes, forward_kernel_name, cuda_cmds = gen_forward_codes(input_nodes, inter_nodes, output_nodes, cmds, hard_reset=True)
backward_codes, backward_kernel_name, input_bp_vars = gen_backward_codes(cuda_cmds, input_nodes, output_nodes, cmds, hard_reset=True, detach_reset=True, surrogate_fuction=surrogate.ATan())
print(f'forward_codes = \n{forward_codes}')
print(f'backward_codes = \n{backward_codes}')
Then we will get the following generated CUDA code as output:
forward_codes =
extern "C" __global__
void forward_kernel_697806161140619033
(const float *x_seq, float *v_v_seq, float *h_seq, float *spike_seq, const float &v_threshold, const float &v_reset, const int &neuron_num, const int &numel, const float &input_tau, const float &input_v_reset)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < neuron_num)
{
const int dt = neuron_num;
for(int mem_offset = 0; mem_offset < numel; mem_offset += neuron_num)
{
const int t = index + mem_offset;
{
const float input_x = x_seq[t];
const float input_v_last = v_v_seq[t];
float inter_9 = input_v_last - input_v_reset;
float inter_11 = input_x - inter_9;
float inter_13 = inter_11 / input_tau;
float output_h = input_v_last + inter_13;
h_seq[t] = output_h;
}
if (h_seq[t] >= v_threshold)
{
spike_seq[t] = 1.0f;
v_v_seq[t + dt] = v_reset;
}
else
{
spike_seq[t] = 0.0f;
v_v_seq[t + dt] = h_seq[t];
}
}
}
}
backward_codes =
extern "C" __global__
void backward_kernel__3595517059288953692
(const float *grad_spike_seq, const float *grad_v_seq, const float *h_seq, const float *spike_seq, float *grad_x_seq, float *grad_v_init, const float &v_threshold, const float &v_reset, const int &neuron_num, const int &numel, const float &input_tau)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < neuron_num)
{
float grad_output_h = 0.0f; // grad_output_h will be used recursively
for(int mem_offset = numel - neuron_num; mem_offset >= 0; mem_offset -= neuron_num)
{
const int t = index + mem_offset;
const float over_th = h_seq[t] - v_threshold;
// start: spikingjelly.activation_based.surrogate.ATan.cuda_code
const float sg_ATan_M_PI_2__alpha__x = ((float) 1.57079632679489661923) * 2.0f * over_th;
const float grad_s_to_h = 2.0f / 2.0f / (1.0f + sg_ATan_M_PI_2__alpha__x * sg_ATan_M_PI_2__alpha__x);
// end: spikingjelly.activation_based.surrogate.ATan.cuda_code
const float grad_v_to_h = 1.0f - spike_seq[t];
// output_h = input_v_last + inter_13;
float grad_input_v_last = grad_output_h;
float grad_inter_13 = grad_output_h;
// inter_13 = inter_11 / input_tau;
float grad_inter_11 = grad_inter_13 / input_tau;
// inter_11 = input_x - inter_9;
float grad_input_x = grad_inter_11;
float grad_inter_9 = - grad_inter_11;
// inter_9 = input_v_last - input_v_reset;
grad_input_v_last += grad_inter_9;
//
grad_output_h = grad_spike_seq[t] * grad_s_to_h + (grad_v_seq[t] + grad_input_v_last) * grad_v_to_h;
}
// output_h = input_v_last + inter_13;
float grad_input_v_last = grad_output_h;
float grad_inter_13 = grad_output_h;
// inter_13 = inter_11 / input_tau;
float grad_inter_11 = grad_inter_13 / input_tau;
// inter_11 = input_x - inter_9;
float grad_input_x = grad_inter_11;
float grad_inter_9 = - grad_inter_11;
// inter_9 = input_v_last - input_v_reset;
grad_input_v_last += grad_inter_9;
//
grad_v_init[index] = grad_input_v_last;
}
}
开源脉冲神经网络深度学习框架
https://spikingjelly.readthedocs.io
Python Cuda Markdown
Dear OpenI User
Thank you for your continuous support of the OpenI Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the OpenI Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《OpenI Qizhi Community AI Collaboration Platform Usage Agreement》