|
- '''
- # @time:2023/6/5 10:51
- # Author:Tuan
- # @File:train.py
- '''
- import glob
- import os.path
- from datetime import datetime
- import argparse
- import torch.nn
- from torch.nn import CrossEntropyLoss, MSELoss
- from torch.utils.data import DataLoader
- from tqdm import tqdm
- # from utils.dataset.TT_Dataset_new import MyDataset # 读取数据所用函数
- from utils import MyDataset#, ThreeDataset
- from modules import UNet
- from ddpm import Diffusion
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # def estimate(y_label, y_pred):
- # # y_pred[y_label==0]=0
- # # 准确率
- # acc = np.mean(np.equal(y_label.cpu().numpy(), y_pred.cpu().numpy()) + 0)
- #
- # return acc
-
- # 参数
def _str2bool(value):
    """Parse a command-line string into a real boolean.

    argparse's ``type=bool`` is a classic trap: ``bool("False")`` is ``True``
    because any non-empty string is truthy. This converter accepts the usual
    spellings of yes/no so ``--pretrain False`` actually disables the flag.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def get_args_parser():
    """Build the CLI parser and parse ``sys.argv`` into training arguments.

    Returns:
        argparse.Namespace: hyper-parameters and input/output paths for a run.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--batch_size', default=2, type=int, help='')
    parser.add_argument('--epoches', default=260, type=int)
    # BUG FIX: these two used ``type=bool``, which maps ANY non-empty string
    # (including "False") to True; _str2bool parses the value correctly while
    # keeping the value-taking CLI shape backward-compatible.
    parser.add_argument('--pretrain', default=False, type=_str2bool,
                        help='if start pretrain for seg')
    parser.add_argument('--load_weight', default=False, type=_str2bool,
                        help='Load pretrain weights')
    parser.add_argument('--ckpt_url', default="",
                        help='Load pretrain weights')
    # Model parameters
    parser.add_argument('--model_name', default='hrnet_m', type=str, metavar='MODEL')
    parser.add_argument('--num_classes', default=7, type=int)
    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--betas', type=float, default=(0.9, 0.95))
    parser.add_argument('--lr', type=float, default=1e-4)
    # Cosine-annealing-with-warm-restarts schedule parameters
    parser.add_argument('--T_0', type=int, default=3, metavar='余弦退火参数')
    parser.add_argument('--T_mult', type=int, default=13, metavar='余弦退火参数')
    parser.add_argument('--lr_min', type=float, default=1e-5, metavar='余弦退火参数')
    # Dataset parameters
    parser.add_argument('--train_path', default='data/dataset/cut', type=str)
    parser.add_argument('--result_path', default='Train/result', type=str,
                        help="验证集结果保存路径")
    parser.add_argument('--model_path', default='/model', type=str)
    # Logging parameters
    parser.add_argument('--log_dir', default='/model',
                        help='path where to tensorboard log')
    parser.add_argument('--log_dir_txt', default='Train/logs',
                        help='path where to txt log')
    parser.add_argument('--log_describe',
                        default='unet \n'
                                'imagesize: 256*256 \n'
                                "for numpy",
                        type=str)
    args = parser.parse_args()

    return args
-
args = get_args_parser()
print(args, flush=True)

# Directory where validation/visualisation results are written.
# exist_ok=True replaces the racy ``if os.path.exists(...) == False`` check.
result_save_path = os.path.join(args.result_path, args.model_name)
os.makedirs(result_save_path, exist_ok=True)

# Timestamp tag used for this run. The original code added 8 to the local
# hour (assumes the server clock is UTC, shifting to UTC+8 — TODO confirm),
# which could yield an hour >= 24; shifting the full datetime with a
# timedelta handles day/hour rollover correctly.
now = datetime.now(timezone.utc) + timedelta(hours=8)
now = f"{now.month}_{now.day}_{now.hour}_{now.minute}"

# ---- Dataset -------------------------------------------------------------
# Training images and their segmentation labels live side by side under
# <train_path>/image and <train_path>/lab.
imagePath = os.path.join(args.train_path, "image")
lab_seg_Path = os.path.join(args.train_path, "lab")

# trainDataset = ThreeDataset(imagePath, lab_seg_Path, lab_sl_Path, lab_line_Path)
trainDataset = MyDataset(imagePath, lab_seg_Path)
trainDatasetloader = DataLoader(trainDataset, args.batch_size, shuffle=True)
trainLen = len(trainDatasetloader)
print(f'Length of dataset :{trainLen}', flush=True)

# ---- Model, optimizer, diffusion helper ----------------------------------
model = UNet().to(device)
model.train()
total = sum(param.nelement() for param in model.parameters())
print("Number of parameter: %.2fM" % (total / 1e6), flush=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay, betas=args.betas)
diffusion = Diffusion(img_size=256, device=device)

# Losses and learning-rate schedule.
Loss_type = CrossEntropyLoss()  # NOTE(review): unused by the diffusion loop below
Loss_mse = MSELoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=args.T_0,         # epochs before the first warm restart
    T_mult=args.T_mult,   # period multiplier: T_0 *= T_mult after each restart
    eta_min=args.lr_min,  # floor learning rate
)
-
max_acc = 0.0

# ---- Training loop -------------------------------------------------------
for epoch in range(args.epoches):
    print(f"\n----------------------------------------------epoch: {epoch}"
          "----------------------------------------------", flush=True)
    loss_total_seg = 0.0

    for i, data in enumerate(trainDatasetloader):
        image, label = data
        image = image.to(device)
        label = label.to(device)

        optimizer.zero_grad()

        # Forward process: corrupt the label map with noise at a random
        # timestep; reverse process: the model predicts that noise, and the
        # MSE between true and predicted noise is the training loss.
        label = label.unsqueeze(1)  # add channel dim: (B, H, W) -> (B, 1, H, W)
        t = diffusion.sample_timesteps(label.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(label, t)
        predicted_noise = model(x_t, t)
        loss = Loss_mse(noise, predicted_noise)

        loss.backward()
        optimizer.step()

        # BUG FIX: the running loss was never accumulated, so the printed
        # average (and the loss baked into checkpoint filenames) was always 0.
        loss_total_seg += loss.item()

        lr = optimizer.param_groups[0]["lr"]  # current learning rate
        print("\r", " train: epoch: {}, step: {}/{}, lr: {}, seg_loss: {} ".format(
            epoch, i, trainLen, lr, round(float(loss_total_seg / (i + 1)), 4)),
            end='', flush=True)

    # Step the warm-restart schedule once per epoch and report the epoch mean.
    scheduler.step()
    epoch_loss_seg = loss_total_seg * 1.0 / trainLen
    print("\r epoch: {}, epoch_loss_seg: {}, ".format(epoch, round(float(epoch_loss_seg), 4)),
          end='', flush=True)

    # Checkpoint name "<epoch>-<avg loss>.pth" under <model_path>/<model_name>.
    save_name = str(epoch + 1) + '-' + str(round(float(epoch_loss_seg), 4)) + ".pth"
    save_model_path = os.path.join(args.model_path, args.model_name)
    model_Path = os.path.join(save_model_path, save_name)
    os.makedirs(save_model_path, exist_ok=True)

    # Save every 10 epochs. NOTE(review): the second branch only fires when
    # --epoches is raised above 1000; the original comment claimed it covered
    # the "last 10 epochs" — confirm the intended condition.
    if (epoch + 1) % 10 == 0 and (epoch + 1) > 1:
        torch.save(model.state_dict(), model_Path)
    elif (epoch + 1) > 1000:
        torch.save(model.state_dict(), model_Path)
-
-
-
-
-
-
-
|