|
- '''
- # @time:2023/4/7 20:01
- # Author:Tuan
- # @File:train_line.py
- '''
- '''
- # @time:2023/3/13 15:46
- # Author:Tuan
- '''
import argparse
import glob
import os.path
from datetime import datetime, timedelta

import torch.nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# from utils.dataset.TT_Dataset_new import MyDataset  # dataset reader (unused)
from utils.dataset.TT_dataset_3l import MyDataset_up  # , ThreeDataset
from Model.SegFormer.segformer import SegFormer
from Model.UNet.unet import UNet
from Model.DF.DFM import Line_Net
from utils.losses.loss import FocalLoss2d
# Run on the GPU when one is available, otherwise fall back to the CPU.
# Every tensor/model below is moved onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # 参数
def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` is a classic trap: ``bool(s)`` is True for ANY
    non-empty string, so ``--pretrain False`` used to parse as True.  This
    helper accepts the common true/false spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognizable boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % (value,))


def get_args_parser():
    """Build the CLI parser, parse ``sys.argv`` and return the namespace.

    Returns:
        argparse.Namespace with training, model, optimizer, dataset and
        logging options.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--batch_size', default=32, type=int, help='')
    parser.add_argument('--epoches', default=500, type=int)
    # BUGFIX: was ``type=bool`` -- "--pretrain False" silently parsed as True.
    parser.add_argument('--pretrain', default=False, type=_str2bool,
                        help='if start pretrain for seg')
    parser.add_argument('--load_weight', default=False, type=_str2bool,
                        help='Load pretrain weights')
    parser.add_argument('--ckpt_url', default="",
                        help='Load pretrain weights')
    # Model parameters
    parser.add_argument('--model_name', default='segformer', type=str, metavar='MODEL')
    parser.add_argument('--num_classes', default=7, type=int)
    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05)
    # BUGFIX: betas is a (beta1, beta2) pair; without nargs=2 the CLI could
    # only produce a single float, breaking any ``betas=args.betas`` consumer.
    parser.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.95))
    parser.add_argument('--lr', type=float, default=1e-4)
    # Cosine-annealing warm-restart schedule parameters (see the commented
    # CosineAnnealingWarmRestarts block below in this file).
    parser.add_argument('--T_0', type=int, default=3, metavar='余弦退火参数')
    parser.add_argument('--T_mult', type=int, default=13, metavar='余弦退火参数')
    # NOTE(review): 1e-50 is effectively zero as a floor learning rate --
    # confirm this is intended rather than e.g. 1e-5.
    parser.add_argument('--lr_min', type=float, default=1e-50, metavar='余弦退火参数')
    parser.add_argument('--seg_scale', type=float, default=0.3, help='第一次分割的损失权重')
    parser.add_argument('--line_scale', type=float, default=0.7, help='第二次分割的损失权重')
    # Dataset parameters
    parser.add_argument('--train_path', default='/dataset', type=str)
    parser.add_argument('--val_path', default='Train/coastline/train_256', type=str)
    parser.add_argument('--acc_save', default=0.80, type=float, help="大于该阈值时保存验证集结果")
    parser.add_argument('--result_path', default='Train/result', type=str, help="验证集结果保存路径")
    parser.add_argument('--model_path', default='/model', type=str)
    # distributed training parameters
    parser.add_argument('--log_dir', default='/model',
                        help='path where to tensorboard log')
    parser.add_argument('--log_dir_txt', default='Train/logs',
                        help='path where to txt log')
    parser.add_argument('--log_describe',
                        default='unet \n'
                                'imagesize: 256*256 \n'
                                "for numpy",
                        type=str)
    args = parser.parse_args()

    return args
-
args = get_args_parser()
print(args, flush=True)

# Directory for validation results of this model run.
result_save_path = os.path.join(args.result_path, args.model_name)
os.makedirs(result_save_path, exist_ok=True)  # idiomatic "create if missing"

# Timestamp used to name the tensorboard run directory.
# BUGFIX: the original built the string from ``now.hour + 8`` (apparently a
# UTC -> UTC+8 shift), which could yield invalid hours such as "25"; shifting
# the datetime itself makes hour/day/month roll over correctly.
# NOTE(review): confirm the host clock is UTC -- otherwise this 8h offset
# double-shifts an already-local time.
now = datetime.now() + timedelta(hours=8)
now = str(now.month) + '_' + str(now.day) + '_' + str(now.hour) + '_' + str(now.minute)

# Tensorboard (visualization) directory: <log_dir>/<model_name>/<timestamp>
tensorboardPath = os.path.join(args.log_dir, args.model_name, now)
os.makedirs(tensorboardPath, exist_ok=True)
- # writer = SummaryWriter(tensorboardPath)
-
- # 数据处理
- #训练集
- imagePath = os.path.join(args.train_path,"image")
- lab_seg_Path = os.path.join(args.train_path, "lab_seg")
- lab_sl_Path = os.path.join(args.train_path, "lab_sl")
- lab_line_Path = os.path.join(args.train_path, "lab_line")
-
- # 构建数据集
- # trainDataset = ThreeDataset(imagePath, lab_seg_Path, lab_sl_Path, lab_line_Path)
- trainDataset = MyDataset_up( lab_seg_Path, lab_line_Path)
- trainDatasetloader = DataLoader(trainDataset, args.batch_size, shuffle=True)
- trainLen = len(trainDatasetloader)
- print(f'Lenth of dataset :{trainLen}',flush=True)
-
# Model definition: dual-stage line network, 7 input channels / 7 classes.
model = Line_Net(in_channels=7,numclass=7).to(device)

# NOTE(review): captured but never used below -- dead variable?
state_dict = model.state_dict(keep_vars=True)
# model = HighResolutionNet(num_class=args.num_classes).to(device)
# Parameter count, reported in millions.
total = sum([param.nelement() for param in model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6),flush=True)

# Loss functions and optimizer

Loss_Facal = FocalLoss2d(alpha=0.25,gamma=2)
Loss_type = CrossEntropyLoss()
# NOTE(review): Loss_sl / Loss_line are constructed but not used in the
# visible training loop -- presumably kept for the (commented) 3-label setup.
Loss_sl = CrossEntropyLoss()
Loss_line = CrossEntropyLoss()
# NOTE(review): lr / weight_decay / betas are hard-coded here and do NOT use
# args.lr / args.weight_decay (default 0.05!) / args.betas -- confirm which
# values are intended before wiring args in.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01, betas=(0.9, 0.95))
# optimizer = torch.optim.AdamW(pg1, lr=args.lr, weight_decay=args.weight_decay, betas=args.betas)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
#     optimizer,
#     T_0=args.T_0,        # epochs before the first restart
#     T_mult=args.T_mult,  # restart-period multiplier: T_0 = T_0 * T_mult after each restart
#     eta_min=args.lr_min  # minimum learning rate
# )
-
max_acc = 0.0

# Training loop
for epoch in range(args.epoches):
    print(
        f"\n----------------------------------------------epoch: {epoch}----------------------------------------------", flush=True)
    loss_total_seg = 0.0
    loss_total_line = 0.0
    acc_all = 0
    # Polynomial decay factor.
    # NOTE(review): ``m`` is recomputed every epoch but never read -- dead code?
    m = ((1 - epoch / args.epoches) ** 0.9) * (0.9 - 0.9 / 100) + 0.9 / 100

    for i, data in enumerate(trainDatasetloader):
        image, label = data
        image = image.to(device)
        label = label.to(device)
        # Clear gradients from the previous step.
        optimizer.zero_grad()

        # Forward pass (idiomatic ``model(x)`` rather than ``model.forward(x)``
        # so nn.Module hooks still fire).
        output = model(image)
        # Weighted blend of cross-entropy and focal loss.
        loss_all = 0.75 * Loss_type(output, label.long()) + 0.25 * Loss_Facal(output, label.long())
        # BUGFIX: accumulate a Python float (.item()), not the loss tensor --
        # summing tensors kept every batch's autograd graph alive and grew
        # GPU memory over the course of an epoch.
        loss_total_seg += loss_all.item()

        loss_all.backward()
        optimizer.step()

        lr = optimizer.param_groups[0]["lr"]  # current learning rate (for logging)

    # Mean losses over this epoch.
    # scheduler.step()
    epoch_loss_seg = loss_total_seg * 1.0 / trainLen
    epoch_loss_line = loss_total_line * 1.0 / trainLen
    print("\r epoch: {}, epoch_loss_seg: {}, epoch_loss_line: {}".format(epoch, round(float(epoch_loss_seg), 4), round(float(epoch_loss_line), 4)), end='', flush=True)

    # Checkpointing: saved as <epoch>-<seg_loss>.pth under <model_path>/<model_name>/
    save_name = str(epoch + 1) + '-' + str(round(float(epoch_loss_seg), 4)) + ".pth"
    save_model_path = os.path.join(args.model_path, args.model_name)
    model_Path = os.path.join(save_model_path, save_name)
    os.makedirs(save_model_path, exist_ok=True)

    # Save a checkpoint every 5 epochs.  (Original comments claimed "every 20
    # epochs" / "last 10 epochs"; the code below is what actually runs.)
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), model_Path)
    # Unreachable with the default 500 epochs; kept for parity with longer runs.
    elif (epoch + 1) > 1000:
        torch.save(model.state_dict(), model_Path)
-
-
-
|