|
- '''
- # @time:2023/4/4 15:46
- # Author:Tuan
- '''
- import glob
- import os.path
- from datetime import datetime
- import argparse
- import torch.nn
- from torch.nn import CrossEntropyLoss
- from torch.utils.data import DataLoader
- import torchvision.transforms as transforms
- #from utils.dataset.TT_Dataset_new import MyDataset # 读取数据所用函数
- from utils.dataset.TT_dataset_3l import MyDataset#, ThreeDataset
- from model.FCN_ResNet import FCN_ResNet
- # from Model.HrNet.hrnet_vit import HighResolutionNet
- from utils.utils import Logger
- from torch.nn import functional as F
- from utils.losses.loss import CrossEntropyLoss2dLabelSmooth, FocalLoss2d
# Select the compute device once at import time: GPU when visible, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # def estimate(y_label, y_pred):
- # # y_pred[y_label==0]=0
- # # 准确率
- # acc = np.mean(np.equal(y_label.cpu().numpy(), y_pred.cpu().numpy()) + 0)
- #
- # return acc
-
-
-
-
# Command-line arguments
def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` is a well-known trap: ``bool("False")`` is
    ``True`` because every non-empty string is truthy.  This converter
    accepts the usual spellings explicitly and keeps ``default=True``
    working unchanged.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError("boolean value expected, got %r" % value)


def get_args_parser():
    """Build the argument parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: all training / model / dataset / logging options.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--batch_size', default=32, type=int, help='')
    parser.add_argument('--epoches', default=800, type=int)
    # BUG FIX: was ``type=bool``, which parsed '--pretrain False' as True
    # (any non-empty string is truthy); defaults are unchanged.
    parser.add_argument('--pretrain', default=True, type=_str2bool,
                        help='if start pretrain for seg')
    parser.add_argument('--load_weight', default=True, type=_str2bool,
                        help='Load pretrain weights')
    parser.add_argument('--ckpt_url', default="",
                        help='Load pretrain weights')
    # Model parameters
    parser.add_argument('--model_name', default='FCN_ResNet101', type=str, metavar='MODEL')
    parser.add_argument('--num_classes', default=7, type=int)
    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05)
    # BUG FIX: was ``type=float`` with a tuple default, so the option could
    # never actually be supplied on the command line; ``nargs=2`` keeps the
    # same default and makes '--betas 0.9 0.99' parse as two floats.
    parser.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.95))
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--T_0', type=int, default=3, metavar='余弦退火参数')
    parser.add_argument('--T_mult', type=int, default=13, metavar='余弦退火参数')
    parser.add_argument('--lr_min', type=float, default=1e-5, metavar='余弦退火参数')
    parser.add_argument('--seg_scale', type=float, default=0.3, help='第一次分割的损失权重')
    parser.add_argument('--line_scale', type=float, default=0.7, help='第二次分割的损失权重')
    # Dataset parameters
    parser.add_argument('--train_path', default='/dataset', type=str)
    parser.add_argument('--val_path', default='Train/coastline/train_256', type=str)
    parser.add_argument('--acc_save', default=0.80, type=float, help="大于该阈值时保存验证集结果")
    parser.add_argument('--result_path', default='Train/result', type=str, help="验证集结果保存路径")
    parser.add_argument('--model_path', default='/model', type=str)
    # Logging parameters
    parser.add_argument('--log_dir', default='/model',
                        help='path where to tensorboard log')
    parser.add_argument('--log_dir_txt', default='Train/logs',
                        help='path where to txt log')
    parser.add_argument('--log_describe',
                        default='unet \n'
                                'imagesize: 256*256 \n'
                                "for numpy"
                        ,
                        type=str,)
    args = parser.parse_args()

    return args
-
from datetime import timedelta

args = get_args_parser()
print(args, flush=True)

# Folder that holds validation results for this model run.
result_save_path = os.path.join(args.result_path, args.model_name)
os.makedirs(result_save_path, exist_ok=True)

# Timestamp used to name the log / tensorboard directories.
# BUG FIX: the original computed ``str(now.hour + 8)`` (a UTC -> UTC+8
# shift, presumably for CST — TODO confirm) which produced invalid hours
# such as "28" after 16:00 UTC.  Shifting the whole datetime lets the
# day/month roll over correctly.
now = datetime.now() + timedelta(hours=8)
now = str(now.month) + '_' + str(now.day) + '_' + str(now.hour) + '_' + str(now.minute)

# Tensorboard directory (the SummaryWriter itself is currently disabled).
tensorboardPath = os.path.join(args.log_dir, args.model_name, now)
os.makedirs(tensorboardPath, exist_ok=True)
# writer = SummaryWriter(tensorboardPath)
-
# ---------------- Data ----------------
# Training-set sub-directories: input images plus three label types
# (presumably: semantic seg, sea/land, and coastline — TODO confirm
# against MyDataset).
imagePath = os.path.join(args.train_path,"image")
lab_seg_Path = os.path.join(args.train_path, "lab_seg")
lab_sl_Path = os.path.join(args.train_path, "lab_sl")
lab_line_Path = os.path.join(args.train_path, "lab_line")

# Build the dataset and loader.
# trainDataset = ThreeDataset(imagePath, lab_seg_Path, lab_sl_Path, lab_line_Path)
trainDataset = MyDataset(imagePath, lab_seg_Path,lab_sl_Path, lab_line_Path)
trainDatasetloader = DataLoader(trainDataset, args.batch_size, shuffle=True)
trainLen = len(trainDatasetloader)  # number of batches per epoch
print(f'Lenth of dataset :{trainLen}',flush=True)
-
# ---------------- Model ----------------
# FCN with a ResNet-101 backbone, placed on the selected device.
model = FCN_ResNet( num_classes=args.num_classes, backbone='resnet101',).to(device)
# Snapshot of the model's parameter dict, kept as Variables so the
# checkpoint-merge below updates it in place.
state_dict = model.state_dict(keep_vars=True)
total = sum(param.nelement() for param in model.parameters())
print("Number of parameter: %.2fM" % (total / 1e6),flush=True)

print(args.load_weight)
if args.load_weight:
    print("Starting to load the pre training model with:", args.ckpt_url, flush=True)
    # NOTE(review): torch.load unpickles the checkpoint — only load files
    # from trusted sources.
    checkpoint = torch.load(args.ckpt_url,map_location=device)
    # Merge only the checkpoint entries whose names exist in this model,
    # so extra keys in the checkpoint are silently ignored.
    overlap = {name: weight for name, weight in checkpoint.items() if name in state_dict}
    state_dict.update(overlap)
    model.load_state_dict(state_dict)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=args.betas)
    model.train()
    print("xia you ren wu",flush=True)
else:
    print("not pre training model",flush=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=args.betas)
    print("strat pre_train",flush=True)
-
- # if args.pretrain:
- # '''选择要更新的参数'''
- # # pg0,pg1 = [], []
- # # for k, v in model.named_parameters():
- # # if '_2' in k:
- # # v.requires_grad = True
- # # pg0.append(v)
- # # else:
- # # v.requires_grad = False
- # # pg1.append(v)
- # optimizer = torch.optim.AdamW(model_line.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=args.betas)
- # model_line.train()
- # print("xia you ren wu",flush=True)
- # else:
- # optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=args.betas)
- # print("train",flush=True)
-
-
-
# ---------------- Losses & LR schedule ----------------

# Focal loss for hard/rare pixels; label-smoothed cross-entropy as the base term.
Loss_Facal = FocalLoss2d(alpha=0.1,gamma=4)
Loss_Cross = CrossEntropyLoss2dLabelSmooth()
# optimizer = torch.optim.AdamW(pg1, lr=args.lr, weight_decay=args.weight_decay, betas=args.betas)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=args.T_0, # number of epochs before the first warm restart
    T_mult=args.T_mult, # restart-period multiplier: after each restart, T_0 = T_0 * T_mult
    eta_min=args.lr_min # minimum learning rate
)
-
# ---------------- Training loop ----------------
for epoch in range(args.epoches):
    print(
        f"\n----------------------------------------------epoch: {epoch}----------------------------------------------",flush=True)
    # Accumulate plain floats: the original summed loss *tensors*, which
    # kept every iteration's autograd graph alive for the whole epoch.
    loss_total_seg = 0.0
    loss_total_line = 0.0

    for i, data in enumerate(trainDatasetloader):
        image, label, label_sl, label_line = data
        image = image.to(device)
        label = label.to(device)
        label_sl = label_sl.to(device)
        label_line = label_line.to(device)
        # Reset gradients for this batch.
        optimizer.zero_grad()

        # First pass: class-segmentation head and sea/land head.
        output, output_sea = model(image)
        if args.load_weight:
            # Fine-tuning: CE + focal mix for the class head, plain CE
            # for the sea/land head.
            loss_type = 0.75 * Loss_Cross(output, label.long()) + 0.25 * Loss_Facal(output, label.long())
            loss_sea = Loss_Cross(output_sea, label_sl.long())
            loss_seg = 0.4*loss_type + 0.6*loss_sea
            loss_total_seg += loss_type.item()
            # Second pass: coastline head — skipped early in training when
            # the batch contains no line pixels.
            if torch.sum(label_line) == 0 and epoch < 80:
                loss_line = 0
            else:
                output = model.forward_line(output, image)
                loss_line = 0.5 * Loss_Cross(output, label_line.long()) + 0.5 * Loss_Facal(output, label_line.long())
                loss_total_line += loss_line.item()
            # BUG FIX: ``seg_scale`` was never assigned on this branch, so
            # the combination below raised NameError on the first batch.
            # Use the CLI weight (default 0.3, documented as the first-pass
            # loss weight) — TODO confirm the intended value.
            seg_scale = args.seg_scale
        else:
            # Training from scratch: no line loss yet.
            loss_type = 0.75 * Loss_Cross(output, label.long()) + 0.25 * Loss_Facal(output, label.long())
            loss_sea = Loss_Cross(output_sea, label_sl.long())
            loss_seg = loss_type + 1.5*loss_sea
            loss_total_seg += loss_type.item()
            loss_line = 0
            seg_scale = 1

        # Combined objective; ``loss_line`` may be the int 0, which is fine.
        loss_all = seg_scale*loss_seg + (1 - seg_scale)*loss_line
        loss_all.backward()
        optimizer.step()

        lr = optimizer.param_groups[0]["lr"]  # current learning rate

        print("\r"," train: epoch: {}, step: {}/{}, lr: {}, seg_loss: {} , line_loss: {}".format(epoch, i, trainLen, lr,
                                        round(float(loss_total_seg / (i + 1)), 4)
                                        ,round(float(loss_total_line / (i + 1)), 4)),
              end='',flush=True)

    # Per-epoch average losses; step the cosine-annealing schedule once per epoch.
    scheduler.step()
    epoch_loss_seg = loss_total_seg * 1.0 / trainLen
    epoch_loss_line = loss_total_line * 1.0 / trainLen
    print("\r epoch: {}, epoch_loss_seg: {}, epoch_loss_line: {}".format(epoch, round(float(epoch_loss_seg), 4),round(float(epoch_loss_line), 4)), end='',flush=True)

    # ---------------- Checkpointing ----------------
    save_name = str(epoch + 1) + '-' + str(round(float(epoch_loss_seg), 4)) + ".pth"  # checkpoint file name
    save_model_path = os.path.join(args.model_path, args.model_name)  # checkpoint directory
    model_Path = os.path.join(save_model_path, save_name)
    os.makedirs(save_model_path, exist_ok=True)

    # Save every 20th epoch.  NOTE(review): the ``> 1000`` branch is
    # unreachable with the default 800 epochs; kept for longer runs.
    if ((epoch + 1) % 20 == 0 and (epoch + 1) > 1):
        torch.save(model.state_dict(), model_Path)
    elif ((epoch + 1) > 1000):
        torch.save(model.state_dict(), model_Path)
|