|
- import glob
- import os
- import time
- import datetime
- import cv2
- import torch
- from models.models.dsamnet import DSAMNet
- from utils import EMA
- from utils.c_train_and_eval import train_one_epoch, evaluate, create_lr_scheduler
- from utils.change_data_aug import MyDataset
- # from models.NestedUNet.Models import SNUNet_ECAM
- from utils.c_distributed_utils import ConfusionMatrix
- from models.FC.FC_EF import FresUNet
- from models.models.siamunet_conc import SiamUNet_conc
- from models.models.siamunet_diff import SiamUNet_diff
- from models.models.stanet import STANet
- from models.models.ifn import DSIFN
- from models.models.cdnet import CDNet
- from models.models import unet
- from models.models.snunet import SNUNet
- # from models.models.myne_corsst import MyNet
- # from models.models.mynet_corsst2 import MyNet
- # from models.models.mynet import MyNet
- from models.mynet.mynet import MyNet
- from argparse import Namespace
- from models.seg_model.unet import UNet
- import warnings
- import wandb
- from utils.distributed_utils import set_seed
-
- warnings.filterwarnings("ignore")
- """
- 读取数据集:RGB三通道,0-255范围内
- 变化检测
- """
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
- import math
-
-
def parse_args():
    """Build the default training configuration.

    Returns an ``argparse.Namespace`` holding every hyper-parameter and
    path used by :func:`train`; no command-line parsing is performed.

    Note: the Windows paths are raw strings — the original used bare
    backslashes (``"D:\\Datasets..."`` with ``\\D`` unescaped), which is a
    deprecated escape sequence and a SyntaxWarning on Python 3.12+.
    The resulting path values are unchanged.
    """
    args = Namespace(
        project_name='ContrastModel',
        batch_size=64,
        data_path=r"D:\Datasets\Data_CD\CD_Data_GZ\CD_Data_GZ\256\train",
        val_path=r"D:\Datasets\Data_CD\CD_Data_GZ\CD_Data_GZ\256\val",
        out_path="./output",
        device="cuda",
        num_classes=1,          # foreground classes; train() adds background
        lr=0.0004,
        momentum=0.9,
        optim_type='Adam',      # NOTE(review): train() hard-codes AdamW — confirm
        print_freq=50,
        epochs=100,
        resume="",              # checkpoint path to resume from ("" = fresh run)
        start_epoch=0,
        save_path='checkpoint.pt',
        ckpt_url=r"",           # pretrained weights to warm-start from
        amp=False,              # mixed-precision training
        weight_decay=1e-4,
        seed=10,
        model_ema=False,        # exponential moving average of weights
        model_ema_decay=0.99998,
        model_ema_steps=32)
    return args
-
-
- args = parse_args()
-
-
def train(args=args):
    """Run the full change-detection training loop.

    Trains the model on ``args.data_path``, evaluates on ``args.val_path``
    after every epoch, logs metrics to wandb, and saves the checkpoint with
    the best F1 score under ``./output/<timestamp>/best.pth``.

    Returns the trained model.

    Bug fix: the original NaN guard compared the float ``F1`` against the
    string ``"nan"``, which is always False, so a NaN F1 was silently kept
    (and, being unordered, could never update ``best_F1``).  It now uses
    ``math.isnan``.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    batch_size = args.batch_size
    # +1 for the background class expected by the confusion matrix.
    num_classes = args.num_classes + 1

    # Text file collecting per-epoch training/validation statistics.
    results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    train_dataset = MyDataset(args.data_path)
    val_dataset = MyDataset(args.val_path)

    num_workers = 4
    print(num_workers)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True
                                               )

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             pin_memory=True
                                             )
    # Alternative architectures kept for quick comparison experiments:
    # model = SNUNet(3, 2)
    model = MyNet(3, 2)
    # model = UNet(in_channels=3, num_classes=2)
    # model = FresUNet(6, 2)
    # model = SiamUnet_conc(3, 2)
    # model = SiamUnet_diff(3, 2)
    # model = DSIFN()  # out5 h,w; out4 h/2,w/2; out3 h/4,w/4; out2 h/8,w/8; out1 h/16,w/16
    # model = DSAMNet(in_ch=3, out_ch=2)  # dist, ds2, ds3 are all (b, 2, h, w), same size as input
    # model = STANet(in_ch=3)
    model.to(device)
    model_ema = None
    if args.model_ema:
        # Scale the EMA decay by the effective update frequency.
        adjust = 1 * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = EMA.ExponentialMovingAverage(model, device=device, decay=1.0 - alpha)
    # NOTE(review): AdamW is used regardless of args.optim_type — confirm intended.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs,
                                       warmup=True, warmup_epochs=5)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None
    # Warm-start from pretrained weights (model only).
    if args.ckpt_url:
        print("使用预训练模型", args.ckpt_url)
        checkpoint = torch.load(args.ckpt_url, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
    # Resume a previous run (model + optimizer + scheduler + epoch counter).
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        # Mixed-precision scaler state must be restored as well.
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    # Wall-clock timing for the whole run.
    start_time = time.time()
    best_F1 = 0.
    Last_epoch = 0
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    save_path = os.path.join("output", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    os.makedirs(save_path)
    wandb.init(project=args.project_name, config=args.__dict__, name=nowtime, save_code=True)
    for epoch in range(args.start_epoch, args.epochs):
        mean_loss, lr = train_one_epoch(
            model, optimizer, train_loader, device, epoch,
            lr_scheduler=lr_scheduler,
            print_freq=args.print_freq,
            num_classes=2,
            scaler=scaler, ema=model_ema)

        confmat = evaluate(model, val_loader,
                           device=device,
                           num_classes=num_classes, print_freq=args.print_freq)
        val_info = ConfusionMatrix.todict(confmat)
        val_info_print = str(confmat)
        # Per-class validation metrics; index 1 is the "change" class.
        precision = float(val_info['precision'][1])
        average_row_correct = float(val_info['average row correct'][1])
        Iou = float(val_info['IoU'][1])
        recall = float(val_info['recall'][1])
        Avg_precision = val_info['Avg_precision']
        F1 = float(val_info['F1_Score'][1])
        mean_Iou = val_info['mean IoU']

        print(val_info_print)
        # BUG FIX: F1 is a float, so the old `F1 == "nan"` string comparison
        # was always False; detect NaN properly and zero it out.
        if math.isnan(F1):
            F1 = 0.0
        save_txt = os.path.join(save_path, results_file)
        print(save_txt)
        with open(save_txt, "a") as f:
            # Append this epoch's train_loss, lr and validation metrics.
            train_info = f"[epoch: {epoch}]\n" \
                         f"train_loss: {mean_loss:.4f}\n" \
                         f"lr: {lr:.6f}\n"

            f.write(train_info + val_info_print + "\n\n")
        if F1 > best_F1:
            best_F1 = F1
            Last_epoch = epoch
            # Checkpoint location on the training server.
            model_name = "best.pth"
            save_url = os.path.join(save_path, model_name)
            print(save_url)
            # NOTE(review): saves the full module object; a state_dict would be
            # more portable across code changes — confirm before changing.
            torch.save(model, save_url)
        wandb.log(
            {'epoch': epoch, 'F1': F1, 'precision': precision, 'IoU': Iou, 'recall': recall, 'mean_Iou': mean_Iou,
             'average_row_correct': average_row_correct, 'Avg_precision': Avg_precision, "lr": lr,
             "mean_loss": mean_loss, "best_F1": best_F1})
    print("best model in {} epoch".format(Last_epoch))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("training time {}".format(total_time_str))
    # Upload the training scripts to wandb for reproducibility.
    arti_code = wandb.Artifact('python', type='code')
    arti_code.add_file('./utils/change_data_aug.py')
    arti_code.add_file('./utils/c_train_and_eval.py')
    arti_code.add_file('./utils/change_data.py')
    arti_code.add_file('train.py')
    wandb.log_artifact(arti_code)
    wandb.finish()
    return model
-
-
if __name__ == '__main__':
    # Automated setup for running this training script on a compute cluster.
    # args = parse_args()
    set_seed(args)
    wandb.login()
    # Random-search wandb sweep maximizing validation best_F1.
    sweep_config = {
        'method': 'random'
    }
    metric = {
        'name': 'best_F1',
        'goal': 'maximize'
    }
    sweep_config['metric'] = metric
    sweep_config['parameters'] = {}

    # Fixed (non-searched) hyper-parameters.
    sweep_config['parameters'].update({
        'project_name': {'value': 'wandb_demo'},
        'epochs': {'value': 10},
        'ckpt_path': {'value': 'checkpoint.pt'}})

    # Discrete-valued hyper-parameters.
    sweep_config['parameters'].update({
        'optim_type': {
            'values': ['Adam', 'SGD', 'AdamW']
        },
    })
    # Continuous-valued hyper-parameters.
    # NOTE(review): train() reads the module-level `args`, not wandb.config,
    # so these sweep values do not appear to reach the training loop — verify
    # that wandb.agent's config injection is actually consumed somewhere.
    sweep_config['parameters'].update({

        'lr': {
            'distribution': 'log_uniform_values',
            'min': 1e-6,
            'max': 0.1
        },
        'batch_size': {
            'distribution': 'q_uniform',
            'q': 8,
            'min': 8,
            'max': 64,
        },
    })
    sweep_id = wandb.sweep(sweep_config, project=args.project_name)
    # Each of the 5 agent runs calls train() once.
    wandb.agent(sweep_id, train, count=5)
    # train(args)
|