|
- import torch
- import torch.nn as nn
- import numpy as np
- from xcom.pruning.LnStructured import LnStructured
- from xcom.pruning.L1Unstructured import L1Unstructured
def l1_structured_compress(model_origin,model_pruned):
    """Transfer weights from `model_origin` into the smaller `model_pruned`
    using L1 structured pruning.

    Walks both models' modules in lockstep, so `model_pruned` must have the
    same module ordering as `model_origin` (same architecture, fewer
    channels).  Mutates `model_pruned` in place; also zeroes the pruned
    input channels of `model_origin`'s convs (via `channel_prune`).
    Returns None.
    """
    masks = []   # one entry per conv: output-channel mask, or None if the conv was not pruned
    w = []       # weights surviving input-channel pruning, consumed by filter_prune when flag is set
    for [m0, m1] in zip(model_origin.named_modules(), model_pruned.named_modules()):
        if isinstance(m0[1], nn.Conv2d):
            if m0[1].weight.data.shape!=m1[1].weight.data.shape:
                flag = False  # set True by channel_prune once input channels were pruned
                if m0[1].weight.data.shape[1]!=m1[1].weight.data.shape[1]: # input-channel counts differ
                    w,flag=channel_prune(m0,m1,masks,model_pruned,flag,w)
                if m0[1].weight.data.shape[0]!=m1[1].weight.data.shape[0]: # output-channel counts differ
                    filter_prune(m0,m1,masks,flag,w,n=1)
                continue
            else:
                # Shapes match: plain copy, record a placeholder mask so the
                # masks list stays aligned with the conv sequence.
                m1[1].weight.data = m0[1].weight.data.clone()
                masks.append(None)
        elif isinstance(m0[1], nn.BatchNorm2d):
            # NOTE(review): the message reads inverted — the assert actually
            # requires that the pruned model DOES have a BN layer here.
            assert isinstance(m1[1], nn.BatchNorm2d), "There should not be bn layer here."
            if m0[1].weight.data.shape!=m1[1].weight.data.shape:
                adjust_bn(m0,m1,masks)
                continue
            set_bn_value(m0,m1)
-
def l1_unstructured_compress(model_origin,model_pruned):
    """Placeholder for L1 unstructured weight transfer.

    TODO: not implemented — currently a no-op returning None, so selecting
    "l1_unstructured" in `pruning_method` leaves `model_pruned`'s weights
    at their freshly-initialised values.
    """
    return
-
# Dispatch table: pruning-method name -> weight-transfer routine.
# Used by structured_compress to pick the compression strategy.
pruning_method={
    "l1_structured":l1_structured_compress,
    "l1_unstructured":l1_unstructured_compress
}
def structured_compress(model_origin,method,stayed_channels,tr_scrath):
    """Build a pruned copy of `model_origin`.

    Args:
        model_origin: trained model; its class must accept a `cfg` keyword
            describing the per-layer channel configuration.
        method: key into `pruning_method` (e.g. "l1_structured").
        stayed_channels: channel configuration (`cfg`) for the pruned model.
        tr_scrath: if truthy ("train from scratch"), skip weight transfer and
            return a freshly initialised pruned model.  (Name kept for
            backward compatibility; presumably a typo of "tr_scratch".)

    Returns:
        Tuple `(model_pruned, model_origin)`.
    """
    # Both branches of the original built the identical model — construct it once.
    model_pruned = model_origin.__class__(cfg=stayed_channels)
    if not tr_scrath:
        # Transfer (pruned) weights from the original model into the new one.
        pruning_method[method](model_origin, model_pruned)
    return model_pruned,model_origin
-
-
def channel_prune(m0,m1,masks,model_pruned,flag,w):
    """Prune the INPUT channels of m0's conv to match m1's input width.

    Zeroes the dropped input channels in m0's weight, returns the kept
    weight slice `w` and `flag=True`.  When the output-channel counts of
    m0 and m1 already agree, appends a placeholder None mask.
    """
    assert len(masks)>0, "masks is empty!"
    name, conv = m0
    if name.endswith('downsample.conv'):
        # Residual connection: reuse the mask of the matching earlier layer;
        # the lookback distance depends on the block type (bottleneck vs basic).
        offset = -4 if model_pruned.config['depth'] >= 50 else -3
        mask = masks[offset].cpu()
    else:
        mask = masks[-1].cpu()
    kept = np.squeeze(np.argwhere(mask))
    dropped = np.squeeze(np.argwhere(mask == 0))
    if kept.size == 1:
        # squeeze collapses a single index to 0-d; restore 1-d shape
        dropped = np.resize(dropped, (1,))
        kept = np.resize(kept, (1,))
    conv.weight.data[:, dropped.tolist(), :, :] = 0
    w = conv.weight.data[:, kept.tolist(), :, :].clone()
    flag = True
    if conv.weight.data.shape[0] == m1[1].weight.data.shape[0]:
        # Output channels already match: no filter pruning will follow,
        # so keep the masks list aligned with a placeholder.
        masks.append(None)
    return w,flag
def filter_prune(m0,m1,masks,flag,w,n):
    """Prune the OUTPUT filters of m0's conv down to m1's output width.

    For downsample convs the most recent mask is reused; otherwise an
    Ln-structured mask is computed on m0's weight.  Writes the kept
    filters into m1's weight and appends the mask to `masks`.
    """
    name, conv = m0
    if name.endswith('downsample.conv'):
        mask = masks[-1].cpu()
    else:
        # Attach an Ln-structured pruning mask sized to m1's output channels.
        LnStructured.apply(conv, 'weight', m1[1].weight.data.shape[0], n, dim=0)
        mask = conv.weight_mask.sum(dim=(1, 2, 3)).cpu()
    keep = np.squeeze(np.argwhere(mask))
    if keep.size == 1:
        keep = np.resize(keep, (1,))  # restore 1-d shape after squeeze
    # If channel_prune already ran, select filters from its kept slice `w`;
    # otherwise slice straight from the original weight.
    source = w if flag else conv.weight.data
    m1[1].weight.data = source[keep.tolist(), :, :, :].clone()
    masks.append(mask)
def adjust_bn(m0,m1,masks):
    """Copy only the surviving channels of m0's BatchNorm into m1.

    Uses the most recent mask in `masks` to select which channels'
    weight/bias/running statistics to keep.
    """
    mask = masks[-1].cpu()
    keep = np.squeeze(np.argwhere(mask))
    if keep.size == 1:
        keep = np.resize(keep, (1,))  # restore 1-d shape after squeeze
    keep = keep.tolist()
    bn_src, bn_dst = m0[1], m1[1]
    bn_dst.weight.data = bn_src.weight.data[keep].clone()
    bn_dst.bias.data = bn_src.bias.data[keep].clone()
    bn_dst.running_mean = bn_src.running_mean[keep].clone()
    bn_dst.running_var = bn_src.running_var[keep].clone()
def set_bn_value(m0,m1):
    """Copy all BatchNorm parameters and running statistics from m0 to m1.

    Used when the layer was not pruned (shapes already match); clones so
    m1 holds independent tensors.
    """
    src, dst = m0[1], m1[1]
    dst.weight.data = src.weight.data.clone()
    dst.bias.data = src.bias.data.clone()
    dst.running_mean = src.running_mean.clone()
    dst.running_var = src.running_var.clone()
|