|
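- # Read the parameters from the PyTorch model and the parameter names from the TensorLayer model, merge the two, save the result as a model file, then load it into the TensorLayer model to check for errors. Currently runs without errors; correctness of the converted weights is still to be verified.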
- import os
- os.environ['TL_BACKEND'] = 'tensorflow'
- import numpy as np
- import torch
- import tensorlayer as tl
- from training.ms_voxel_cnn_training import MSVoxelCNN as torch_MSVoxelCNN, index_hr, index_lr
- from training.ms_voxel_cnn_training_tensorlayer import MSVoxelCNN as tensorlayer_MSVoxelCNN
-
- if __name__ == "__main__":
- lv = 0
- dslevel = 1
- for gr in range(6):
- saved_model_path = 'Model/MSVoxelCNN/'
- low_res = int(64 // (2 ** (lv + 1)))
- tl_model = tensorlayer_MSVoxelCNN(2, 1, low_res, 4, gr)
- maxpool_n1 = tl.layers.SequentialLayer( #空的也没关系,就不处理
- [tl.layers.MaxPool3d(filter_size=(2, 2, 2), strides=(2, 2, 2), padding='VALID') for _ in range(dslevel - 1)]
- )
- maxpool_n = tl.layers.SequentialLayer(
- [tl.layers.MaxPool3d(filter_size=(2, 2, 2), strides=(2, 2, 2), padding='VALID') for _ in range(dslevel)]
- )
- block_size=int(low_res*2)
- x = tl.layers.Input(shape=(1, 64, 64, 64, 1))
- if (gr == 0):
- input = maxpool_n(x)
- else:
- input = maxpool_n1(x)
- index = index_hr(gr - 1, block_size, block_size, block_size)
- input = tl.convert_to_numpy(input) #只能转numpy,因为第518行的原因
- input = input[:, index[0][:, None, None], index[1][:, None], index[2], :] #可以使用tf.gather_nd,带batch,索引不用到最后一维度
-
- _, ld, lh, lw, _ = input.shape
- index_0 = index_lr(gr - 1, ld, lh, lw)
- if (index_0 is not None):
- input[:, index_0[0][:, None, None], index_0[1][:, None], index_0[2], :] = 0 #这里没法用tf.gather_nd,也没法用assign
- if (gr == 5):
- input[:, 1:ld:2, 1:lh:2, :, :]=0
- input = tl.convert_to_tensor(input)
- tl_model.init_build(input) #这一步要有,走一遍前向推理,把前面没填的in_channels这些参数补上,再执行build
- tl_model_name = 'G' + str(gr) + '_lres' + str(low_res) +'_tensorlayer.npz'
- tl_model.save_standard_weights(tl_model_name)
- tl_weights = np.load(tl_model_name, allow_pickle=True)
- # print("tensorlayer:")
- tl_names = {}
- for param in tl_weights.keys():
- # print(param,"\t",tl_weights[param].shape)
- key = param.split('/')
- last = key[-1]
- keys = key[0]
- if keys not in tl_names.keys():
- tl_names[keys] = []
- tl_names[keys].append(last)
- tl_key_list = list(tl_names.keys())
-
- pt_model = torch_MSVoxelCNN(2, 1, low_res, 4, gr)
- ckp_path = saved_model_path + 'G' + str(gr) + '_lres' + str(low_res) + '/' + "best_model.pt"
- print(ckp_path)
- checkpoint = torch.load(ckp_path)
- pt_model.load_state_dict(checkpoint['state_dict'])
- #np.savez('torch.npz', params=pt_model.state_dict()) #直接保存读出来有问题
- #weights = np.load('torch.npz', allow_pickle=True)['params']
- save_list_names = []
- save_list_var = []
- # for named, values in pt_model.named_parameters(): #缺少BN参数
- # save_list_names.append(named)
- # save_list_var.append(values.detach().numpy())
- # print(named)
- # print(values.detach().numpy())
- torch_keys = {}
- for param, values in pt_model.state_dict().items():
- if 'num_batches_tracked' in param: #无用参数
- continue
- # print(pt_model.state_dict()[param])
- names = param.split('.')
- last = names.pop(-1) #弹出最后一个字段
- keys = '.'.join(names)
- if keys not in torch_keys.keys():
- torch_keys[keys] = {}
- torch_keys[keys][last] = values
- for param, values in torch_keys.items():
- if 'running_mean' in values.keys(): #替换bn的名字
- values["gamma"] = values.pop("weight")
- values["beta"] = values.pop("bias")
- values["moving_mean"] = values.pop("running_mean")
- values["moving_var"] = values.pop("running_var")
- if 'mask' in values.keys(): #合并mask,并改名字
- mask = values.pop("mask")
- weight = values.pop("weight")
- values["kernel"] = weight * mask
- elif 'weight' in values.keys(): #改名字
- values["filters"] = values.pop("weight")
- if "bias" in values.keys():
- values["biases"] = values.pop("bias")
-
- for param, values in torch_keys.items():
- for key, num in values.items():
- if key == 'kernel' or key == 'filters': #reshape
- values[key] = num.permute(2, 3, 4, 1, 0)
- torch2tl_weights = {}
- for iter,(param, values) in enumerate(torch_keys.items()):
- names = tl_names[tl_key_list[iter]]
- for name in names:
- torch2tl_weights[tl_key_list[iter]+'/'+name] = values[name].numpy()
- # for param, values in torch2tl_weights.items():
- # print(param,"\t",values.shape)
- torch2tl_model_name = 'G' + str(gr) + '_lres' + str(low_res) +'_torch2tl.npz'
- np.savez(torch2tl_model_name, **torch2tl_weights)
- tl_model.load_standard_weights(torch2tl_model_name, skip=False, reshape=False, format='npz_dict')
- # print("torch transform to tensorlayer:")
- # weights = np.load(torch2tl_model_name, allow_pickle=True)
- # for param in weights.keys():
- # print(param,"\t",weights[param].shape)
- ''' save_list_names.append(param)
- save_list_var.append(values.detach().numpy())
- print(param,"\t",pt_model.state_dict()[param].shape)
- if 'Resnet.blocks.2.block.0.1.weight' == param or 'Resnet.blocks.2.block.0.1.bias' == param:
- print(values)
- save_var_dict = {save_list_names[idx]: val for idx, val in enumerate(save_list_var)}
- np.savez('torch.npz', **save_var_dict)
- weights = np.load('torch.npz', allow_pickle=True)
- for param in weights.keys():
- print(param,"\t",weights[param].shape)
- # print(weights[param])
- tl_weights = np.load('tensorlayer.npz', allow_pickle=True)
- print("tensorlayer:")
- for param in tl_weights.keys():
- print(param,"\t",tl_weights[param].shape)
- if 'batchnorm3d_1/beta' == param or 'batchnorm3d_1/gamma' == param:
- print(tl_weights[param])
- '''
|