|
"""Contrastive-learning pre-training."""
import argparse
import logging
import os
import time

import clip
import torch
import tqdm
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import PatchDataset
from loss import *
from model import CL_baseline, CLIP_WSI_Model
from utils import train_epoch_pre, valid_epoch_pre, test_pre
-
def run(args):
    """Run contrastive-learning (optionally CLIP text-guided) pre-training.

    Builds patch datasets/dataloaders for tumor and normal patches, the
    model and contrastive losses, then trains for ``args.epochs`` epochs.
    In CLIP mode (``args.clip``) each epoch is validated, the checkpoint
    with the best validation accuracy is kept, and that checkpoint is
    evaluated at the end.

    Args:
        args: argparse.Namespace produced by ``parse_args()``. ``args.device``
            is overwritten here with the actually available torch device.
    """
    # Data augmentation; normalization uses ImageNet statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transformer = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomAffine(0, shear=0.2, scale=(0.8, 1.2)),  # random affine transform
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(64.0 / 255, 0.75, 0.25, 0.04),
        transforms.ToTensor(),
        normalize
    ])
    val_transformer = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize
    ])

    # Select the device, falling back to CPU when CUDA is unavailable.
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    # Print the experiment configuration.
    print(args, flush=True)

    # Build datasets.
    print('===> Building dataset...')
    tumor_train_dataset = PatchDataset(transform=train_transformer,
                                       dataset_name=args.dataset,
                                       mode='tumor_train')
    tumor_valid_dataset = PatchDataset(transform=val_transformer,
                                       dataset_name=args.dataset,
                                       mode='tumor_val')
    normal_train_dataset = PatchDataset(transform=train_transformer,
                                        dataset_name=args.dataset,
                                        mode='normal_train')
    normal_valid_dataset = PatchDataset(transform=val_transformer,
                                        dataset_name=args.dataset,
                                        mode='normal_val')

    tumor_train_dataloader = DataLoader(tumor_train_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True)
    tumor_val_dataloader = DataLoader(tumor_valid_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=False)
    normal_train_dataloader = DataLoader(normal_train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True)
    normal_val_dataloader = DataLoader(normal_valid_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=False)

    # Build the model.
    print('===> Building model...')
    if args.clip:
        model = CLIP_WSI_Model(vision_backbone=args.vb,
                               linearProbe=False)
    else:
        model = CL_baseline(backbone=args.vb,
                            pretrain='None',
                            linearProbe=False)
    model.to(args.device)

    # Loss functions. FIX: moved to args.device instead of hard-coded .cuda()
    # so the script no longer crashes on CPU-only machines (the device
    # selection above explicitly allows a CPU fallback).
    loss_fn = [SimMaxLoss(metric='cos', alpha=args.alpha).to(args.device),
               SimMinLoss(metric='cos').to(args.device),
               SimMaxLoss(metric='cos', alpha=args.alpha).to(args.device)]

    # Optimizer: SGD with momentum.
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    # Initialize training/validation/test summaries.
    summary_train = {'epoch': 0, 'step': 0, 'train_loss1': float('inf'),
                     'train_loss2': float('inf'), 'train_loss3': float('inf'),
                     'train_loss': float('inf'), 'train_acc': 0}
    summary_valid = {'val_loss': float('inf'), 'val_acc': 0}
    summary_test = {'test_acc': 0, 'test_f1': 0, 'test_auc': 0,
                    'test_precision': 0, 'test_recall': 0, 'test_confusion_matrix': ''}
    summary_writer = SummaryWriter(args.save_path)
    acc_valid_best = 0

    # Training loop.
    print('===> Start training')
    # Text prompt used as supervision in CLIP mode.
    text = 'tumor cell features vary and larger than normal cell'
    text = clip.tokenize(text).to(args.device)
    for epoch in range(args.start_epoch, args.epochs):
        # One training epoch.
        summary_train = train_epoch_pre(summary_train, summary_writer, args, model,
                                        loss_fn, optimizer, tumor_train_dataloader,
                                        normal_train_dataloader, text)

        # Save the latest checkpoint together with this epoch's losses.
        torch.save(
            {'epoch': summary_train['epoch'],
             'train_loss1': summary_train['train_loss1'],
             'train_loss2': summary_train['train_loss2'],
             'train_loss3': summary_train['train_loss3'],
             'train_loss': summary_train['train_loss'],
             'state_dict': model.state_dict()},
            os.path.join(args.save_path, 'train.ckpt')
        )

        if args.clip:
            time_now = time.time()
            # Validate once per training epoch.
            print('===> Epoch:{} validation...'.format(epoch))
            summary_valid = valid_epoch_pre(summary_valid, args, model, loss_fn,
                                            tumor_val_dataloader, normal_val_dataloader,
                                            text)
            # Time spent on validation.
            time_spent = time.time() - time_now

            # Log validation accuracy and elapsed time.
            logging.info(
                '{}, Epoch: {}, Step: {}, Validation Acc: {:.3f}, Run Time: {:.2f}'
                .format(
                    time.strftime("%Y-%m-%d %H:%M:%S"), summary_train['epoch'],
                    summary_train['step'] + 1, summary_valid['val_acc'], time_spent
                )
            )

            # Track validation accuracy in TensorBoard.
            summary_writer.add_scalar('valid/acc', summary_valid['val_acc'], summary_train['step'])

            # Keep the checkpoint with the best validation accuracy so far.
            if summary_valid['val_acc'] >= acc_valid_best:
                acc_valid_best = summary_valid['val_acc']
                torch.save(
                    {'epoch': summary_train['epoch'],
                     'step': summary_train['step'],
                     'state_dict': model.state_dict()},
                    os.path.join(args.save_path, 'best.ckpt')
                )

    summary_writer.close()

    if args.clip:
        # Evaluate the best checkpoint (best.ckpt is only written in CLIP mode).
        print('===> Testing...')
        checkpoint = torch.load(os.path.join(args.save_path, 'best.ckpt'))
        model.load_state_dict(checkpoint['state_dict'])
        time_now = time.time()
        summary_test = test_pre(summary_test, args, model, loss_fn,
                                tumor_val_dataloader, normal_val_dataloader,
                                text)
        time_spent = time.time() - time_now
        # Log the final test metrics.
        logging.info(
            '{}, test_acc: {:.4f}, test_f1: {:.4f}, test_auc: {:.4f}, test_precision: {:.4f}, '
            'test_recall: {:.4f}, confusion_matrix: {}, Run Time: {:.2f}'
            .format(
                time.strftime("%Y-%m-%d %H:%M:%S"), summary_test['test_acc'], summary_test['test_f1'],
                summary_test['test_auc'], summary_test['test_precision'], summary_test['test_recall'],
                summary_test['test_confusion_matrix'], time_spent
            )
        )
-
# parameters setting
def parse_args(argv=None):
    """Parse command-line arguments for contrastive-learning pre-training.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``
            (the usual argparse behavior). Passing a list makes the parser
            testable without touching the process arguments.

    Returns:
        argparse.Namespace with the experiment configuration.
    """
    def _str2bool(value):
        # FIX: argparse's `type=bool` treats ANY non-empty string as True,
        # so `--clip False` used to enable CLIP mode. Parse the common
        # boolean spellings explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('true', '1', 'yes', 'y'):
            return True
        if lowered in ('false', '0', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(description='CL pretrain parameters')
    # Vision backbone architecture.
    parser.add_argument('--vision_backbone', dest='vb', default='RN50', type=str,
                        help='[RN50, RN101, RN50x4, RN50x16, RN50x64, ViT-B/32, ViT-B/16, ViT-L/14, ViT-L/14@336px]')
    # Whether to train with CLIP (text-guided) instead of plain contrastive learning.
    parser.add_argument('--clip', dest='clip', default=False, type=_str2bool)
    # Dataset name.
    parser.add_argument('--dataset', dest='dataset', default='NCRF', type=str,
                        help='[NCRF, PCAM]')
    parser.add_argument('--epochs', dest='epochs', default=20, type=int)
    parser.add_argument('--batch_size', dest='batch_size', default=32, type=int)
    parser.add_argument('--base_lr', dest='lr', default=0.01, type=float)
    parser.add_argument('--start_epoch', dest='start_epoch', default=0, type=int)
    # Pretrained weights used for contrastive learning.
    parser.add_argument('--_pretrain', default='None', type=str,
                        help='None, imagenet, CLIP')
    # Save path for both the contrastive and the CLIP model.
    parser.add_argument('--save_path', dest='save_path', default='/model/PRETRAIN_SAVE_PATH', type=str)
    # Contrastive-loss hyper-parameter.
    parser.add_argument('--alpha', default=0.05, type=float)
    # Device name (overwritten in run() with the actually available device).
    parser.add_argument('--device', default='cuda:0', type=str)
    return parser.parse_args(argv)
-
if __name__ == '__main__':
    # Enable INFO-level logging before training so epoch/test logs are visible.
    logging.basicConfig(level=logging.INFO)
    cli_args = parse_args()
    run(cli_args)
|