|
- import glob
- import os
- import time
- import datetime
- import cv2
- import torch
- from utils.b_train_and_eval import train_one_epoch, evaluate, create_lr_scheduler
- from utils.b_change_data import MyDataset
- # from models.b_efficientnetv2.efficientnetV2_s2 import EfficientNet_cd
- from models.b_efficientnetv2.b_RepefficientnetV2_s6 import EfficientNet_cd
import warnings
# Silence library deprecation/user warnings so training logs stay readable.
warnings.filterwarnings("ignore")
# Force synchronous CUDA kernel launches so CUDA errors surface at the
# offending call site instead of a later op (debug aid; slows training).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
def init_mask(path, size):
    """Build initial Canny edge masks for paired change-detection images.

    Reads every PNG under ``<path>/A`` and ``<path>/B``, resizes each to
    (size, size), runs Canny edge detection and scales the result to {0, 1}.

    Args:
        path: dataset root containing 'A' and 'B' sub-directories holding
            corresponding before/after PNG images.
        size: side length in pixels of the square resize.

    Returns:
        (mask1, mask2): lists of edge masks (values 0.0/1.0), with mask1[i]
        and mask2[i] belonging to the same image pair.
    """
    mask1 = []
    mask2 = []
    # Fix: sort both listings -- glob returns files in arbitrary,
    # OS-dependent order, so unsorted A[i]/B[i] may not be the same scene.
    img_path_A = sorted(glob.glob(os.path.join(path, 'A', '*.png')))
    img_path_B = sorted(glob.glob(os.path.join(path, 'B', '*.png')))
    for path_a, path_b in zip(img_path_A, img_path_B):
        img1 = cv2.resize(cv2.imread(path_a, 1), (size, size), interpolation=cv2.INTER_CUBIC)
        # NOTE(review): the 4th positional argument of cv2.Canny is `edges`
        # (an output buffer), not apertureSize -- if an aperture of 5 was
        # intended it must be passed as apertureSize=5. Kept as original.
        mask1.append(cv2.Canny(img1, 100, 200, 5) / 255)

        img2 = cv2.resize(cv2.imread(path_b, 1), (size, size), interpolation=cv2.INTER_CUBIC)
        mask2.append(cv2.Canny(img2, 100, 200, 5) / 255)

    return mask1, mask2
-
-
def main(args):
    """Train EfficientNet_cd for change detection, keeping the best model.

    Builds train/val loaders from ``args.data_path`` / ``args.val_path``,
    trains for ``args.epochs`` epochs with AdamW and a warmup LR schedule
    (optionally AMP), appends per-epoch metrics to a timestamped results
    file, and saves the checkpoint with the best validation F1 to
    ``<args.out_path>/best.pth``.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    batch_size = args.batch_size

    # background + foreground classes for the confusion-matrix evaluation
    num_classes = args.num_classes + 1

    # File used to record training and validation info, stamped per run.
    results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    train_dataset = MyDataset(args.data_path)

    val_dataset = MyDataset(args.val_path)

    num_workers = 8
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True
                                               )

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             pin_memory=True
                                             )

    # NOTE(review): the model's class count is hard-coded to 2 while the
    # evaluation uses num_classes -- confirm they agree for your dataset.
    model = EfficientNet_cd(num_classes=2)
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Per-iteration schedule with a 2-epoch warmup.
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs,
                                       warmup=True, warmup_epochs=2)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    print(args.ckpt_url)
    if args.ckpt_url:
        # Warm-start the weights only (no optimizer/scheduler state).
        checkpoint = torch.load(args.ckpt_url, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
    # Resume a full run: weights + optimizer + scheduler + start epoch.
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        # Restore mixed-precision scaler state as well.
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    start_time = time.time()
    best_pre = 0.
    Last_epoch = 0
    train_pre_mask1, train_pre_mask2 = init_mask(args.data_path, 1024)
    # Fix: validation priors must come from the validation set -- the
    # original built them from args.data_path (the training set).
    val_pre_mask1, val_pre_mask2 = init_mask(args.val_path, 1024)
    for epoch in range(args.start_epoch, args.epochs):
        mean_loss, lr, b_images1_mask, b_images2_mask = train_one_epoch(model, optimizer, train_loader, device, epoch,
                                                                        train_pre_mask1, train_pre_mask2,
                                                                        lr_scheduler=lr_scheduler,
                                                                        print_freq=args.print_freq,
                                                                        num_classes=2,
                                                                        scaler=scaler)

        confmat, b_val_iamges_mask1, b_val_iamges_mask2 = evaluate(model, val_loader, val_pre_mask1, val_pre_mask2,
                                                                   device=device,
                                                                   num_classes=num_classes)
        val_info = str(confmat)
        # NOTE(review): scrapes F1 out of the printed confusion-matrix text;
        # this depends on the exact layout of str(confmat) -- fragile.
        val_f1 = val_info.split("\n")[5].split(":")[1].split(",")[1].split("]")[0][2:-1]
        print(val_info)
        if val_f1 == "nan":
            val_f1 = 0
        else:
            val_f1 = float(val_f1)
        save_txt = os.path.join(os.path.abspath(os.path.join(os.getcwd(), "../..")), "output", results_file)
        print(save_txt)
        with open(save_txt, "a") as f:
            # Record train_loss, lr and validation metrics for each epoch.
            train_info = f"[epoch: {epoch}]\n" \
                         f"train_loss: {mean_loss:.4f}\n" \
                         f"lr: {lr:.6f}\n"

            f.write(train_info + val_info + "\n\n")

        save_file = {"model": model.state_dict(),
                     "optimizer": optimizer.state_dict(),
                     "lr_scheduler": lr_scheduler.state_dict(),
                     "epoch": epoch,
                     "args": args}
        if args.amp:
            save_file["scaler"] = scaler.state_dict()
        if val_f1 > best_pre:
            # Feed this epoch's predicted masks back in as priors for the
            # following epochs.  Fix: move tensors to CPU before .numpy() --
            # Tensor.numpy() raises on CUDA tensors.
            for i in range(len(b_images2_mask)):
                b_images2_mask[i] = b_images2_mask[i].detach().cpu().numpy()
                b_images1_mask[i] = b_images1_mask[i].detach().cpu().numpy()

            for i in range(len(b_val_iamges_mask1)):
                b_val_iamges_mask1[i] = b_val_iamges_mask1[i].detach().cpu().numpy()
                b_val_iamges_mask2[i] = b_val_iamges_mask2[i].detach().cpu().numpy()

            train_pre_mask1, train_pre_mask2 = b_images1_mask, b_images2_mask
            val_pre_mask1, val_pre_mask2 = b_val_iamges_mask1, b_val_iamges_mask2

            best_pre = val_f1
            Last_epoch = epoch
            model_name = "best.pth"
            save_url = os.path.join(args.out_path, model_name)
            print(save_url)
            torch.save(save_file, save_url)
            # Fix: removed a write that referenced undefined names
            # (save1_txt / all_loss) and raised NameError on every new best.
        print("best model in {} epoch".format(Last_epoch))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("training time {}".format(total_time_str))
-
def parse_args():
    """Parse command-line arguments for the training script."""
    import argparse

    def _str2bool(v):
        # Fix: argparse `type=bool` is a trap -- bool("False") is True, so
        # any non-empty string parsed truthy. Accept the usual spellings.
        if isinstance(v, bool):
            return v
        return v.strip().lower() in ("yes", "true", "t", "1")

    parser = argparse.ArgumentParser(description="pytorch fcn training")
    parser.add_argument("--ckpt_url", default="", help="data root")
    parser.add_argument("--data_path", default="/tmp/dataset/train", help="data root")
    parser.add_argument("--val_path", default="/tmp/dataset/val", help="val root")
    parser.add_argument("--out_path", default="/tmp/output", help="val root")
    parser.add_argument("--num-classes", default=1, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("-b", "--batch-size", default=1, type=int)
    parser.add_argument("--epochs", default=200, type=int, metavar="N",
                        help="number of total epochs to train")

    parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    # Mixed precision training parameters (default stays True).
    parser.add_argument("--amp", default=True, type=_str2bool,
                        help="Use torch.cuda.amp for mixed precision training")

    args = parser.parse_args()
    return args
-
if __name__ == '__main__':
    # Automated configuration of the cluster training script entry point.
    args = parse_args()
    # Fix: makedirs(..., exist_ok=True) avoids the check-then-create race
    # of the original `if not os.path.exists(...): os.mkdir(...)`.
    os.makedirs("output", exist_ok=True)

    main(args)
|