|
- import copy
- import os
- from typing import Dict
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import logging
- import numpy as np
- from . import neuron, functional, layer
-
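'''
Known limitations of the Lynxi compiler (lyngor) that shape the code below:

1. Looping over a traced tensor dimension, e.g. ``for t in range(x_seq.shape[0]):``,
   makes the tracer emit:
   TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

2. In-place operations are not supported: an assignment such as ``x[t] = y`` has
   no effect (``x[t]`` is NOT set to ``y``) and no error is raised.

3. 5-D tensors are not supported in model compilation: no tensor with more than
   4 dimensions may appear anywhere in the model.
'''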
-
class BaseNode(nn.Module):
    """Base class for spiking neurons written so that lyngor can compile them.

    Subclasses only implement :meth:`neuronal_charge`; firing, reset and the
    single-/multi-step dispatch are handled here using out-of-place tensor
    operations only.

    Args:
        v_threshold: a neuron fires when its membrane potential is
            ``>= v_threshold``.
        v_reset: potential that firing neurons are set to (hard reset). If
            ``None``, ``v_threshold`` is subtracted instead (soft reset).
        step_mode: ``'s'`` (single step) or ``'m'`` (multi step).
        T: number of time steps; required for ``step_mode='m'``.
        return_v: if ``True``, :meth:`forward` returns ``(spike, v)`` instead
            of only the spikes.
    """

    def __init__(self, v_threshold: float = 1., v_reset: float = 0., step_mode='s', T: int = None,
                 return_v: bool = False):
        super().__init__()
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.step_mode = step_mode
        self.T = T
        self.return_v = return_v

    def neuronal_charge(self, x: torch.Tensor, v: torch.Tensor):
        """Return the membrane potential after integrating input ``x`` into ``v``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def single_step_forward(self, x: torch.Tensor, v: torch.Tensor = None):
        """Advance the neurons by one time step.

        Args:
            x: input for this step.
            v: membrane potential from the previous step; zeros if ``None``.

        Returns:
            ``(spike, v)``: spikes converted to ``x``'s dtype/device (1 where a
            neuron fired, 0 elsewhere) and the membrane potential after reset.
        """
        if v is None:
            v = torch.zeros_like(x)
        v = self.neuronal_charge(x, v)

        spike = (v >= self.v_threshold).to(x)
        if self.v_reset is None:
            # Soft reset: subtract the threshold where a spike was emitted.
            v = v - spike * self.v_threshold
        else:
            # Hard reset written arithmetically instead of a masked in-place
            # assignment, which the Lynxi compiler silently ignores.
            v = (1. - spike) * v + spike * self.v_reset

        return spike, v

    def multi_step_forward(self, x_seq: torch.Tensor, v_init: torch.Tensor = None):
        """Run ``self.T`` single steps over ``x_seq``.

        Args:
            x_seq: inputs indexed by time along dim 0 (at least ``self.T`` steps).
            v_init: initial membrane potential; zeros shaped like ``x_seq[0]``
                if ``None``.

        Returns:
            ``(spike_seq, v)``: spikes concatenated along a new dim 0 and the
            final membrane potential.
        """
        if v_init is None:
            v = torch.zeros_like(x_seq[0])
        else:
            v = v_init
        spike_seq = []
        # Loop over the Python int self.T instead of x_seq.shape[0]; the latter
        # makes the tracer warn about converting a tensor to a Python index.
        for t in range(self.T):
            spike, v = self.single_step_forward(x_seq[t], v)
            # Collect and concatenate: in-place writes into a preallocated
            # tensor are not supported by the compiler.
            spike_seq.append(spike.unsqueeze(0))

        spike_seq = torch.cat(spike_seq)
        return spike_seq, v

    def forward(self, x: torch.Tensor, v: torch.Tensor = None):
        """Dispatch to single- or multi-step processing according to ``step_mode``.

        In ``'m'`` mode ``x`` is expected to have the ``T`` time steps folded
        into dim 0 (time-major), i.e. shape ``[T * N, *features]``. Spikes are
        returned with the same shape as ``x`` and, if ``return_v``, the final
        potential with shape ``[N, *features]``. ``v`` may be such a per-sample
        potential (e.g. one returned by a previous call) or any tensor that
        broadcasts after flattening.

        Raises:
            ValueError: if ``step_mode`` is not ``'s'``/``'m'``, or if ``T`` is
                ``None`` in ``'m'`` mode.
        """
        if self.step_mode == 's':
            spike, v = self.single_step_forward(x, v)
            if self.return_v:
                return spike, v
            return spike

        if self.step_mode == 'm':
            if self.T is None:
                raise ValueError("T must be set when step_mode='m'")
            x_shape = x.shape
            n = x_shape[0] // self.T

            # Fold all non-time dims into one with a single reshape to 3-D
            # [T, N, features]; this stays within the compiler's 4-D limit.
            # The alternative x.flatten(1) followed by unfold_seq(self.T, x)
            # failed to compile for Lynxi.
            x = x.reshape(self.T, n, -1)

            if v is not None:
                if v.numel() == x.shape[1] * x.shape[2]:
                    # Per-sample state [N, *features] -> [N, features] so it lines
                    # up with every x[t]; a plain flatten() would give a 1-D
                    # tensor that fails to broadcast against [N, features] for N > 1.
                    v = v.reshape(x.shape[1], x.shape[2])
                else:
                    # Broadcastable state (scalar / shared across the batch).
                    v = v.flatten()
            spike_seq, v = self.multi_step_forward(x, v)

            spike_seq = spike_seq.flatten(0, 1).reshape(x_shape)

            if self.return_v:
                v = v.reshape([n] + list(x_shape[1:]))
                return spike_seq, v
            return spike_seq

        raise ValueError(f"step_mode must be 's' or 'm', got {self.step_mode!r}")
-
-
class IFNode(BaseNode):
    """Integrate-and-fire neuron: the input is accumulated with no leak."""

    def neuronal_charge(self, x: torch.Tensor, v: torch.Tensor):
        """Return the charged potential ``x + v``."""
        charged = x + v
        return charged
-
class LIFNode(BaseNode):
    """Leaky integrate-and-fire neuron that the Lynxi compiler can trace.

    With ``decay = 1 / tau`` and ``x' = decay * x`` if ``decay_input`` else
    ``x``, one charge step is:

    * hard reset (``v_reset`` is not ``None``):
      ``v = v_reset + (1 - decay) * (v - v_reset) + x'``
    * soft reset (``v_reset is None``):
      ``v = (1 - decay) * v + x'``

    Args:
        tau: membrane time constant.
        decay_input: whether the input is also scaled by ``1 / tau``.
        v_threshold, v_reset, step_mode, T, return_v: see ``BaseNode``.
    """

    def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., step_mode='s', T: int = None, return_v: bool = False):
        super().__init__(v_threshold, v_reset, step_mode, T, return_v)
        self.tau = tau
        self.decay = 1. / tau
        self.decay_input = decay_input

    def neuronal_charge(self, x: torch.Tensor, v: torch.Tensor):
        """Return the membrane potential after one leaky integration step."""
        if self.v_reset is None:
            v = (1. - self.decay) * v
        else:
            # Leak towards v_reset, not towards 0: the "+ v_reset" term was
            # missing, which was only correct for v_reset == 0.
            v = (1. - self.decay) * (v - self.v_reset) + self.v_reset

        if self.decay_input:
            x = x * self.decay

        return v + x
-
-
-
def to_lynxi_supported_module(m_in: nn.Module, T: int):
    """Convert one SpikingJelly module into an equivalent that lyngor can compile.

    * ``layer.Conv2d``, ``layer.BatchNorm2d``, ``layer.MaxPool2d``,
      ``layer.AvgPool2d``, ``layer.AdaptiveAvgPool2d`` and ``layer.Flatten``
      become the matching ``torch.nn`` modules. Conv and batch-norm
      parameters/buffers are copied with ``load_state_dict``.
    * ``neuron.IFNode`` / ``neuron.LIFNode`` become this file's ``IFNode`` /
      ``LIFNode``, built with the given ``T`` and ``return_v=False``.
    * Any other module is deep-copied, moved to CPU and used as-is, and a
      critical log message is emitted.

    Args:
        m_in: module to convert.
        T: number of time steps passed to the converted neurons.

    Returns:
        The converted module.
    """
    if isinstance(m_in, layer.Conv2d):
        m_out = nn.Conv2d(in_channels=m_in.in_channels, out_channels=m_in.out_channels,
                          kernel_size=m_in.kernel_size, stride=m_in.stride, padding=m_in.padding,
                          dilation=m_in.dilation, groups=m_in.groups, bias=m_in.bias is not None,
                          padding_mode=m_in.padding_mode)
        m_out.load_state_dict(m_in.state_dict())

    elif isinstance(m_in, layer.BatchNorm2d):
        # Copy track_running_stats from the source module itself (it used to
        # be taken from `affine`). When the two flags differ, the running
        # stat buffers exist on only one side and load_state_dict fails.
        m_out = nn.BatchNorm2d(num_features=m_in.num_features, eps=m_in.eps, momentum=m_in.momentum,
                               affine=m_in.affine, track_running_stats=m_in.track_running_stats)
        m_out.load_state_dict(m_in.state_dict())

    elif isinstance(m_in, layer.MaxPool2d):
        m_out = nn.MaxPool2d(kernel_size=m_in.kernel_size, stride=m_in.stride, padding=m_in.padding,
                             dilation=m_in.dilation, return_indices=m_in.return_indices,
                             ceil_mode=m_in.ceil_mode)

    elif isinstance(m_in, layer.AvgPool2d):
        m_out = nn.AvgPool2d(kernel_size=m_in.kernel_size, stride=m_in.stride, padding=m_in.padding,
                             ceil_mode=m_in.ceil_mode, count_include_pad=m_in.count_include_pad,
                             divisor_override=m_in.divisor_override)

    elif isinstance(m_in, layer.AdaptiveAvgPool2d):
        m_out = nn.AdaptiveAvgPool2d(output_size=m_in.output_size)

    elif isinstance(m_in, layer.Flatten):
        m_out = nn.Flatten(start_dim=m_in.start_dim, end_dim=m_in.end_dim)

    elif isinstance(m_in, neuron.IFNode):
        m_out = IFNode(v_threshold=m_in.v_threshold, v_reset=m_in.v_reset, step_mode=m_in.step_mode, T=T,
                       return_v=False)

    elif isinstance(m_in, neuron.LIFNode):
        m_out = LIFNode(tau=m_in.tau, v_threshold=m_in.v_threshold, v_reset=m_in.v_reset,
                        decay_input=m_in.decay_input, step_mode=m_in.step_mode, T=T,
                        return_v=False)

    else:
        logging.critical(f'{type(m_in)} is not processed and the origin module is used for lynxi compiling.')
        m_out = copy.deepcopy(m_in).cpu()

    return m_out
-
def to_lynxi_supported_modules(net: 'list | tuple | nn.Sequential', T: int):
    """Convert every module in ``net`` with ``to_lynxi_supported_module``.

    The former annotation ``list or tuple or nn.Sequential`` evaluated to just
    ``list``; a string annotation states the intended union without needing
    extra imports.

    Args:
        net: iterable of modules, e.g. a list, tuple or ``nn.Sequential``.
        T: number of time steps forwarded to each conversion.

    Returns:
        A list of converted modules, in the same order as ``net``.
    """
    return [to_lynxi_supported_module(m_in, T) for m_in in net]
-
-
try:
    # Optional support for Lynxi chips. The helpers below are only defined
    # when the lyngor (compiler) and lynpy (runtime) packages import cleanly;
    # otherwise the failure is logged and this module stays importable.
    import lyngor
    import lynpy
    logging.info(f'lynpy.version={lynpy.version}')
    logging.info(f'lyngor.version={lyngor.version}')

    def torch_tensor_to_lynxi(x: torch.Tensor, device_id: int = 0, to_apu: bool = True):
        """Copy a torch tensor into a ``lynpy.Tensor``.

        Args:
            x: tensor to copy; it is detached and moved to CPU first.
            device_id: Lynxi device id passed to ``lynpy.Tensor``.
            to_apu: if ``True``, ``.apu()`` is called on the result
                (presumably a transfer to the device -- confirm with lynpy docs).

        Returns:
            The resulting ``lynpy.Tensor``.
        """
        x_size_in_byte = x.element_size() * x.numel()
        x = x.detach().cpu().numpy()
        x = lynpy.Tensor(dev_id=device_id, size=x_size_in_byte).from_numpy(x)
        if to_apu:
            x = x.apu()
        return x

    def lynxi_tensor_to_torch(x: lynpy.Tensor, shape: 'tuple | list' = None, dtype: str = None):
        """Copy a ``lynpy.Tensor`` back into a torch tensor.

        Args:
            x: tensor to convert.
            shape, dtype: only when both are given, ``x`` is first
                reinterpreted with ``x.view_as(shape, dtype)``.

        Returns:
            A torch tensor built with ``torch.from_numpy`` (it shares memory
            with the intermediate NumPy array).
        """
        if shape is not None and dtype is not None:
            x = x.view_as(shape, dtype)
        if x.devptr is not None:
            # Presumably a non-None devptr means the data is on the device;
            # bring it to host memory before converting.
            x = x.cpu()
        return torch.from_numpy(x.numpy())

    def compile_lynxi_model(output_dir: str, net: nn.Module, in_data_type: str = 'float32',
                            out_data_type: str = 'float32', input_shape_dict: Dict = None):
        """Compile ``net`` with lyngor and write the result under ``output_dir``.

        Args:
            output_dir: output directory passed to ``lyngor.Builder.build``.
            net: PyTorch model to compile.
            in_data_type, out_data_type: input / output data type names.
            input_shape_dict: passed to ``DLModel.load`` as ``inputs_dict``
                (presumably input name -> shape; confirm with lyngor docs).
                ``None`` means an empty dict.

        Returns:
            ``os.path.join(out_path, 'Net_0')``, where ``out_path`` is the
            directory returned by the builder.
        """
        # Fresh dict per call instead of a shared mutable default argument.
        if input_shape_dict is None:
            input_shape_dict = {}
        model = lyngor.DLModel()
        model.load(net, model_type='Pytorch', in_type=in_data_type, out_type=out_data_type,
                   inputs_dict=input_shape_dict)
        offline_builder = lyngor.Builder(target='apu', is_map=True)
        out_path = offline_builder.build(model.graph, model.params,
                                         out_path=output_dir, apu_only=True)
        logging.info(f'lyngor build output in {out_path}: {os.listdir(out_path)}')
        return os.path.join(out_path, 'Net_0')

    def load_lynxi_model(device_id: int, model_path: str):
        """Return ``lynpy.Model`` for the compiled model at ``model_path`` on device ``device_id``."""
        return lynpy.Model(dev_id=device_id, path=model_path)

except Exception as e:
    # Exception rather than BaseException, so KeyboardInterrupt / SystemExit
    # raised during the optional imports still propagate.
    logging.info(f'spikingjelly.activation_based.lynxi_exchange: {e}')
-
|