'''
Removed the 12->1 fully connected layer;
replaced it with 64x6x6 -> 11.
'''
import math
import time
import copy

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

from gauss_morph_load import Morph
from Simple_FC_D import C3AE

plt.style.use('ggplot')


# Custom margin-based center loss
class Center_loss(nn.Module):
    """Pull each feature towards the center of its own age bin, and keep it
    at least a margin away from every other bin's center."""

    def __init__(self):
        super(Center_loss, self).__init__()

    def forward(self, features, labels, center, box):
        diffs = []         # feature minus its own bin center
        other_losses = []  # hinge terms against all other centers

        for i in range(features.size(0)):
            # Bin index of this sample; MORPH ages start at 16.
            pos = int((labels[i] - 16) / box)

            for j in range(center.size(0)):
                # Required margin grows with the distance between bins.
                M = math.fabs(j - pos) * box
                diff = (center[j] - features[i]).unsqueeze(0)
                if j == pos:
                    diffs.append(diff)
                else:
                    dist = torch.norm(diff, p=2, dim=1)
                    # Penalise only centers that are closer than their margin.
                    other_losses.append(torch.clamp(M - dist, min=0) * 0.5)

        diffs = torch.cat(diffs, dim=0)
        other_sum_loss = torch.cat(other_losses, dim=0)

        center_loss = torch.norm(diffs, p=2, dim=1) * 0.5
        mean_loss = torch.mean(center_loss) + torch.mean(other_sum_loss)

        return mean_loss
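

# A minimal vectorized sketch of the same computation, kept for reference
# only (it is not used below). It assumes `labels` are float ages >= 16 and
# `center` has shape (num_bins, feature_dim), matching Center_loss above.
def center_loss_vectorized(features, labels, center, box):
    pos = ((labels - 16) / box).long()                   # (B,) own bin per sample
    dists = torch.cdist(features, center)                # (B, K) feature-center distances
    idx = torch.arange(center.size(0), device=features.device)
    margins = (idx.unsqueeze(0) - pos.unsqueeze(1)).abs().float() * box  # (B, K)
    own = dists.gather(1, pos.unsqueeze(1)).squeeze(1)   # distance to own center
    mask = torch.ones_like(dists, dtype=torch.bool).scatter_(1, pos.unsqueeze(1), False)
    other = torch.clamp(margins - dists, min=0)[mask] * 0.5
    return own.mean() * 0.5 + other.mean()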


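# For reference, a minimal sketch of how a Gaussian age-label distribution
# such as `label_vec10` could be built. The real encoding lives in
# gauss_morph_load.Morph; the bin layout and sigma below are assumptions
# for illustration only.
def gaussian_label_sketch(age, num_bins=10, age_min=16.0, box=5.6, sigma=2.0):
    centers = age_min + box * (torch.arange(num_bins).float() + 0.5)  # bin midpoints
    probs = torch.exp(-(centers - age) ** 2 / (2 * sigma ** 2))
    return probs / probs.sum()  # normalized, so it is a valid KL target

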
def main():
    print("Training started")
    # Hyperparameters
    train_epoch = 240
    batch_size = 128
    lr = 0.002
    weight_decay = 0.001
    gamma = 0.1
    stop_epoch = 999  # early-stopping patience
    save_name = 'HU_MORPH_23_03_28_21_21'

    device = torch.device('cuda:0')
    beta = 9  # weight of the KL term in the combined loss

    # Load the MORPH dataset
    Morph_train = Morph(mode='train')
    Morph_test = Morph(mode='val')
    loader_train = DataLoader(Morph_train, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
    loader_test = DataLoader(Morph_test, batch_size=batch_size, drop_last=True, num_workers=4)
    train_recorder = 'acc.txt'

    # Bookkeeping for the best results seen so far
    best_val_loss = 500
    best_mae = 500
    stop = 0  # epochs without improvement, for early stopping

    # Model
    model = C3AE().to(device)

    # Loss functions, optimizer, and learning-rate schedule
    l1loss = nn.L1Loss()
    klloss = nn.KLDivLoss(reduction='batchmean')
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[120, 200, 240], gamma=gamma)
    image_x = []    # per-epoch train loss, for the loss curve
    image_y = []    # per-epoch test loss, for the loss curve
    graph_mae = []  # per-epoch test MAE

    # Feature centers for three granularities of age bins
    # (bin widths 18.6 / 5.6 / 1.7 years -> 3 / 10 / 33 centers).
    c0 = torch.randn(3, 32, device=device)
    c1 = torch.randn(10, 64, device=device)
    c2 = torch.randn(33, 128, device=device)
    # Training loop
    center_loss = Center_loss()  # stateless, so build it once
    for epoch in range(train_epoch):
        model.train()
        sum_loss = 0
        sum_lenth = 0
        train_bar = tqdm(loader_train)

        for img_64, label, label_vec3, label_vec10 in train_bar:
            img_64, label = img_64.to(device), label.to(device)
            label_vec3, label_vec10 = label_vec3.to(device), label_vec10.to(device)
            ce_vector, logits, cf0, cf1, cf2, c0, c1, c2 = model(
                img_64, mode="train", label=label,
                feature_center0=c0, feature_center1=c1, feature_center2=c2, lr=lr)

            # Center losses at the three feature levels
            loss_center0 = center_loss(features=cf0, labels=label, center=c0, box=18.6)
            loss_center1 = center_loss(features=cf1, labels=label, center=c1, box=5.6)
            loss_center2 = center_loss(features=cf2, labels=label, center=c2, box=1.7)
            loss_cus = loss_center0 + loss_center1 + loss_center2

            # KL divergence on the age distribution plus L1 on the regressed age
            ce_vector = ce_vector.view_as(label_vec10)
            ce_vector = F.log_softmax(ce_vector, dim=1)
            logits = logits.view_as(label)
            KL_LOSS = klloss(ce_vector, label_vec10)
            L1_LOSS = l1loss(logits, label)
            train_loss = beta * KL_LOSS + L1_LOSS
            loss = loss_cus + train_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running mean loss for this epoch
            sum_loss += loss.item() * len(label)
            sum_lenth += len(label)
            epoch_loss = sum_loss / sum_lenth
            train_bar.desc = "[epoch {}] mean loss {}".format(epoch, round(epoch_loss, 7))
        image_x.append(epoch_loss)
        scheduler.step()

        # Evaluation
        model.eval()
        with torch.no_grad():
            test_sum_loss = 0
            total_num = 0
            mae = 0
            for img_64, label, label_vec3, label_vec10 in loader_test:
                img_64, label = img_64.to(device), label.to(device)
                label_vec3, label_vec10 = label_vec3.to(device), label_vec10.to(device)
                ce_vector, logits, cf0, cf1, cf2, c0, c1, c2 = model(
                    img_64, label=label,
                    feature_center0=c0, feature_center1=c1, feature_center2=c2, lr=lr)

                ce_vector = ce_vector.view_as(label_vec10)
                ce_vector = F.log_softmax(ce_vector, dim=1)
                logits = logits.view_as(label)
                KL_LOSS = klloss(ce_vector, label_vec10)
                L1_LOSS = l1loss(logits, label)
                test_loss = beta * KL_LOSS + L1_LOSS

                cur_mae = torch.abs(logits - label).mean()
                mae += cur_mae.item() * len(label)
                test_sum_loss += test_loss.item() * len(label)
                total_num += len(label)
            test_avg_loss = test_sum_loss / total_num
            test_mae = mae / total_num
            graph_mae.append(test_mae)
            image_y.append(test_avg_loss)
            print('epoch:', epoch, 'test_avg_loss:', test_avg_loss, 'mae:', test_mae)
        end_time = time.asctime(time.localtime(time.time()))
        old_best_loss = best_val_loss

        # Remember the weights of the best epoch so far
        if test_avg_loss <= best_val_loss:
            best_val_loss = test_avg_loss
            best_model_wts = copy.deepcopy(model.state_dict())
        if test_mae <= best_mae:
            best_mae = test_mae

        # Early-stopping counter: bump it when the best loss did not improve
        if best_val_loss == old_best_loss:
            stop += 1
        else:
            stop = 0

        # Append this epoch to the training log
        with open('/code/' + train_recorder, 'a') as f:
            f.write('\n' + 'epoch:' + str(epoch))
            f.write(' ' + 'train_loss:' + str('%.6f' % epoch_loss))
            f.write(' ' + 'val_loss:' + str('%.6f' % test_avg_loss))
            f.write(' ' + 'best_val_loss:' + str('%.6f' % best_val_loss))
            f.write(' ' + 'mae:' + str('%.6f' % test_mae))
            f.write(' ' + 'best_mae:' + str('%.6f' % best_mae))
            f.write(' ' + 'finish time:' + str(end_time))
        if stop >= stop_epoch:
            break

    # Restore the best weights, then save the model and its parameters
    model.load_state_dict(best_model_wts)
    torch.save(model, '/model/' + save_name + str(epoch) + '.pkl')
    torch.save(model.state_dict(), '/model/param_' + save_name + str(epoch) + '.pkl')

    # Loss curves
    fig = plt.figure(figsize=[10, 10])
    plt.subplot(111)
    x_values = list(range(1, epoch + 2))  # x axis: epoch numbers
    plt.plot(x_values, image_x, label='train loss')
    plt.plot(x_values, image_y, label='test loss')
    plt.legend(['train loss', 'test loss'])
    plt.savefig(fname="/model/" + save_name + '.eps')
    plt.show()

    # MAE curve
    fig = plt.figure(figsize=[10, 10])
    plt.subplot(111)
    plt.plot(x_values, graph_mae, label='test mae')
    plt.legend(['test mae'])
    plt.savefig(fname="/model/" + save_name + '_mae.eps')
    plt.show()
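
    # To reload the saved parameters later, something along these lines
    # (sketch only; the paths match the torch.save calls above):
    #   model = C3AE()
    #   model.load_state_dict(torch.load('/model/param_' + save_name + str(epoch) + '.pkl'))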


if __name__ == '__main__':
    main()