|
- #!/usr/bin/env python
- # coding: utf-8
-
- # # train
-
- # In[1]:
-
-
- # get_ipython().system('pwd')
-
-
- # # import
-
- # In[2]:
-
-
- import io
- import os
- import sys
- import time
- import json
- import math
- import datetime
- import argparse
- from tqdm import tqdm
- import numpy as np
- from pathlib import Path
- from collections import defaultdict, deque
- import glob
- import gc
- import timm
- import torch
- import torch.distributed as dist
- from torch import nn, einsum
- from torch.nn import functional as F
- import torch.backends.cudnn as cudnn
- import torch.utils.checkpoint as checkpoint
- import datetime
-
- from typing import Iterable, Optional
- from timm.models import create_model
- from timm.optim import create_optimizer
- from timm.scheduler import create_scheduler
- from timm.data import Mixup,create_transform
- from timm.models.registry import register_model
- from timm.models.layers import DropPath, trunc_normal_
- from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
- from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.utils import accuracy, ModelEma,NativeScaler, get_state_dict, ModelEma
-
- from torchvision import datasets, transforms
- from torchvision.datasets.folder import ImageFolder, default_loader
- from torch.optim.lr_scheduler import _LRScheduler
- from functools import partial
- from einops import rearrange
-
- import math
- import torch
- from torch import nn
- from torchvision import models
- # from torchsummary import summary
- import torchvision
-
-
- # In[3]:
-
-
# In[3]:


# Select the GPU when available, otherwise fall back to CPU. The bare
# `device` expression below is notebook cell-output residue (no effect as a script).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device


# In[4]:


print(torch.__version__)
-
-
- # ## utils
-
- # In[5]:
-
-
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
-
-
- # In[6]:
-
-
def init_distributed_mode(args):
    """Configure distributed-training state on *args* from the environment.

    Reads torchrun-style variables (RANK/WORLD_SIZE/LOCAL_RANK) or SLURM
    (SLURM_PROCID); when neither is present, marks args.distributed False
    and returns without touching the process group.
    """
    env = os.environ
    if 'RANK' in env and 'WORLD_SIZE' in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    elif 'SLURM_PROCID' in env:
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # silence stdout on non-master ranks
    setup_for_distributed(args.rank == 0)
-
-
- # In[7]:
-
-
def get_world_size():
    """Number of distributed processes (1 when not running distributed)."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
-
-
- # In[8]:
-
-
def cal_flops_params_with_fvcore(model, inputs):
    """
    Print the model's FLOPs and parameter count (in millions) using fvcore.

    Args:
        model: the nn.Module to analyze.
        inputs: example input(s) forwarded through the model to count FLOPs.
    """
    # imported lazily so fvcore is only required when this helper is used
    from fvcore.nn import FlopCountAnalysis, parameter_count_table, parameter_count
    flops = FlopCountAnalysis(model, inputs)
    params = parameter_count(model)
    # parameter_count returns a dict keyed by submodule name; '' is the whole model
    print('flops(fvcore): %f M' % (flops.total()/1000**2))
    print('params(fvcore): %f M' % (params['']/1000**2))
-
-
- # In[9]:
-
-
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather *tensor* from every distributed process and concatenate along dim 0.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)
-
-
- # In[10]:
-
-
def save_on_master(*args, **kwargs):
    """
    Persist a checkpoint via torch.save.

    NOTE(review): the is_main_process() guard is commented out, so every
    rank currently writes the checkpoint — confirm this is intended.
    """
    # if is_main_process():
    torch.save(*args, **kwargs)
-
-
- # In[11]:
-
-
def get_rank():
    """
    Rank of the current process (0 when not running distributed).
    """
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
-
-
- # In[12]:
-
-
def set_seed(seed=1234):
    """Seed every RNG used in this notebook so results are reproducible.

    Seeds numpy, the stdlib ``random`` module, torch (CPU and CUDA), pins
    cuDNN to deterministic kernels, and fixes the Python hash seed.

    Args:
        seed: integer seed applied to all generators.
    """
    import random  # stdlib; was missing from the top-level imports (NameError before this fix)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # no-op on CPU-only machines
    # When running on the CuDNN backend, force deterministic kernels
    torch.backends.cudnn.deterministic = True
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
-
-
- # In[13]:
-
-
class WarmUpLR(_LRScheduler):
    """Linear warmup learning-rate scheduler.

    Scales each base LR by last_epoch / total_iters, so the LR ramps
    linearly from 0 up to the base LR over the warmup iterations.

    Args:
        optimizer: wrapped optimizer (e.g. SGD)
        total_iters: total iterations of the warmup phase
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """LR for step m of warmup: base_lr * m / total_iters
        (the 1e-8 keeps the division safe when total_iters is 0)."""
        scale = self.last_epoch / (self.total_iters + 1e-8)
        return [scale * base_lr for base_lr in self.base_lrs]
-
-
- # In[14]:
-
-
def build_dataset(is_train, args):
    """Build the train/val dataset with its transform.

    Uses an mcloader ClassificationDataset when args.use_mcloader is set,
    otherwise a torchvision ImageFolder rooted at args.data_path.

    Returns:
        (dataset, nb_classes) — nb_classes is fixed at 2 (binary task).
    """
    split = 'train' if is_train else 'val'
    transform = build_transform(is_train, args)
    if args.use_mcloader:
        from mcloader import ClassificationDataset
        dataset = ClassificationDataset(split, pipeline=transform)
    else:
        dataset = datasets.ImageFolder(os.path.join(args.data_path, split),
                                       transform=transform)
    nb_classes = 2

    return dataset, nb_classes
-
-
def build_transform(is_train, args):
    """Build torchvision preprocessing for the train or eval split.

    Train: photometric augmentation (brightness/contrast jitter) and a random
    horizontal flip applied to the PIL image, followed by tensor conversion
    and per-channel standardization with the dataset statistics.
    Eval: deterministic — tensor conversion and standardization only.

    Fixes over the previous version:
      * augmentations now run before ToTensor/Normalize, so ColorJitter sees
        pixel values in its expected image range instead of standardized
        (possibly negative) data;
      * RandomHorizontalFlip is no longer applied at eval time, which made
        validation metrics non-deterministic.

    Args:
        is_train: True for the training pipeline, False for eval.
        args: unused; kept for interface compatibility with callers.

    Returns:
        a transforms.Compose pipeline.
    """
    # standardize with the dataset's (grayscale replicated) channel stats
    normalize = transforms.Normalize(
        [0.22041939, 0.22041939, 0.22041939],
        [0.19473027, 0.19473027, 0.19473027],
    )
    if is_train:
        return transforms.Compose([
            transforms.ColorJitter(brightness=0.5),  # random brightness change
            transforms.ColorJitter(contrast=0.5),    # random contrast change
            transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
            transforms.ToTensor(),                   # PIL -> tensor, scales 0-255 -> 0-1
            normalize,
        ])
    return transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
-
-
- # In[15]:
-
-
def build_weight(path='', nu_class=1000):
    """Compute per-sample weights for the training split.

    Scans ``<path>/train/<class>/<file>`` and gives every sample the weight
    total_samples / class_count, so rarer classes get proportionally larger
    weights (for use with a weighted sampler). Weights are emitted grouped
    by class in first-seen order, matching the previous behavior.

    Fixes over the previous version: removed a stray debug ``print()`` and
    replaced fragile ``'/'.split`` path parsing with os.path functions.

    Args:
        path: dataset root containing a ``train`` directory.
        nu_class: unused; kept for backward compatibility.

    Returns:
        list of float weights, one entry per training sample.
    """
    files = glob.glob(os.path.join(path, 'train', '*', '*'))
    counts = {}
    for f in files:
        # class name is the sample file's parent directory
        cls = os.path.basename(os.path.dirname(f))
        counts[cls] = counts.get(cls, 0) + 1
    total = len(files)
    weight = []
    for cls, cnt in counts.items():
        weight.extend([total / cnt] * cnt)
    return weight
-
-
- # In[16]:
-
-
class SmoothedValue(object):
    """Track a stream of scalar values, exposing a sliding-window view
    (median/avg/max over the last ``window_size`` updates) alongside a
    running global average over everything seen so far."""

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        # default display: window median plus global average
        self.fmt = fmt if fmt is not None else "{median:.4f} ({global_avg:.4f})"

    def update(self, value, n=1):
        """Record *value* (a batch statistic covering *n* samples)."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """All-reduce count/total across ranks.
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
-
class MetricLogger(object):
    """Collect named SmoothedValue meters and pretty-print training progress."""

    def __init__(self, delimiter="\t"):
        # unknown metric names lazily get a default SmoothedValue
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Update meters by keyword; tensors are converted to Python scalars."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # expose meters as attributes, e.g. logger.loss -> self.meters['loss']
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """All-reduce every meter's count/total across distributed ranks."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a pre-configured meter under *name*."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from *iterable*, printing progress (step, ETA, meters,
        per-iteration and data-loading time, and peak GPU memory when CUDA is
        available) every *print_freq* iterations and on the final one."""
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # zero-pad the step counter to the width of the final count
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # time spent waiting on the data loader
            data_time.update(time.time() - end)
            yield obj
            # full iteration time: data wait + caller's work between yields
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
-
-
- # In[17]:
-
-
def sumloss(loss_cls, nce_loss, sim_loss, aux_loss1):
    """Combine the training losses into a single scalar.

    NOTE(review): ``sim_loss[0]`` is added twice, exactly as in the original
    (loss_cls + nce_loss + sim_loss[0] + sim_loss[0] + aux_loss1) — confirm
    whether the second term was meant to be ``sim_loss[1]``.
    """
    total = loss_cls + nce_loss
    total = total + sim_loss[0] + sim_loss[0]
    return total + aux_loss1
-
-
- # # train&&eval
-
- # In[18]:
-
-
- # def train_one_epoch(model,class_criterion,nce_criterion,
- # data_loader, optimizer: torch.optim.Optimizer,
- # device: torch.device, epoch: int, loss_scaler,#,warmup,warmupstep:int,
- # max_norm: float = 0,
- # model_ema: Optional[ModelEma] = None,# mixup_fn: Optional[Mixup] = None,
- # set_training_mode=True):
- # print("1:{}".format(torch.cuda.memory_allocated(0)))
- # model.train(set_training_mode)
- # print("2:{}".format(torch.cuda.memory_allocated(0)))
- # metric_logger = MetricLogger(delimiter=" ")
- # print("3:{}".format(torch.cuda.memory_allocated(0)))
- # metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
- # header = 'Epoch: [{}]'.format(epoch)
- # print_freq = 10
- # print("4:{}".format(torch.cuda.memory_allocated(0)))
- # # if epoch < warmupstep:
- # # warmup_scheduler.step()
- # loss_cls1 = 0
- # nce_loss1= 0
- # sim_loss1= 0
- # aux_loss11= 0
- # aux_loss12= 0
- # aux_loss13= 0
- # print("5:{}".format(torch.cuda.memory_allocated(0)))
- # for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
- # print("6:{}".format(torch.cuda.memory_allocated(0)))
- # samples = samples.to(device, non_blocking=True)
- # targets = targets.to(device, non_blocking=True)
- # print(targets)
- # # print(targets.shape)
- # # tar1 = targets.float()
- # # if mixup_fn is not None:
- # # samples, targets = mixup_fn(samples, targets)
- # print("7:{}".format(torch.cuda.memory_allocated(0)))
- # with torch.cuda.amp.autocast():
- # cl,log,lab,fe,attn = model(samples)
- # # cl = model(samples)
- # # print(cl)
- # # print(cl.shape)
- # # cl =
- # # print(lab.shape)、
- # print("8:{}".format(torch.cuda.memory_allocated(0)))
- # loss_cls = class_criterion(cl,targets.float())
- # loss_clsvalue = loss_cls.item()
- # nce_loss = nce_criterion(log,lab.float())
- # print("9:{}".format(torch.cuda.memory_allocated(0)))
- # loss_ncevalue = nce_loss.item()
- # sim_loss = similiarity(attn)
- # aux_loss1= aux_similiarity(fe[0][0],fe[3][0])
- # aux_loss2= aux_similiarity(fe[1][0],fe[4][0])
- # aux_loss3= aux_similiarity(fe[2][0],fe[5][0])
- # print("10:{}".format(torch.cuda.memory_allocated(0)))
- # # print(loss_cls ,nce_loss ,sim_loss ,aux_loss1,aux_loss2,aux_loss3)
- # #cancer,logits,labels,features,attn
- # loss_value = sumloss(loss_cls ,nce_loss ,sim_loss ,aux_loss1,aux_loss2,aux_loss3)
- # # loss_value = loss_clsvalue
- # print(loss_value)
- # print("11:{}".format(torch.cuda.memory_allocated(0)))
- # # loss_cls1 += loss_cls
- # # nce_loss1+= nce_loss
- # # sim_loss1+= sim_loss
- # # aux_loss11+= aux_loss1
- # # aux_loss12+= aux_loss2
- # # aux_loss13+= aux_loss3
- # gc.collect()
- # # loss_value = loss_cls.item()
- # print("12:{}".format(torch.cuda.memory_allocated(0)))
- # if not math.isfinite(loss_value):
- # print("Loss is {}, stopping training".format(loss_value))
- # sys.exit(1)
- # print("13:{}".format(torch.cuda.memory_allocated(0)))
- # optimizer.zero_grad()
-
- # # this attribute is added by timm on one optimizer (adahessian)
- # print("14:{}".format(torch.cuda.memory_allocated(0)))
- # is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
- # print("15:{}".format(torch.cuda.memory_allocated(0)))
- # loss_scaler(loss_value, optimizer, clip_grad=max_norm,
- # parameters=model.parameters(), create_graph=is_second_order)
-
- # torch.cuda.synchronize()
- # if model_ema is not None:
- # model_ema.update(model)
- # print("16:{}".format(torch.cuda.memory_allocated(0)))
- # metric_logger.update(loss=loss_value)
- # metric_logger.update(lr=optimizer.param_groups[0]["lr"])
- # gc.collect()
- # # gather the stats from all processes
- # metric_logger.synchronize_between_processes()
- # print(loss_cls1 ,nce_loss1 ,sim_loss1 ,aux_loss11,aux_loss12,aux_loss13)
- # print("Averaged stats:", metric_logger)
- # return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
-
- # def train_one_epoch(model, train_loader, optimizer):
- # model.train()
- # scaler = torch.cuda.amp.GradScaler()
- # losses_all, bce_all, tverskly_all = 0, 0, 0
- # class_criterion = nn.BCEWithLogitsLoss()
- # nce_criterion = nn.CrossEntropyLoss()
- # pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc='Train ')
- # for _, (images, lab) in pbar:
- # optimizer.zero_grad()
- # images = images.to(device, non_blocking=True)
- # lab = lab.to(device, non_blocking=True)
- # with torch.cuda.amp.autocast(enabled=True):
- # cl,log,lab,fe,attn = model(images)
- # print(cl.shape)
- # print(lab.shape)
- # loss_cls = class_criterion(cl,lab.float())
- # loss_clsvalue = loss_cls.item()
- # # nce_loss = nce_criterion(log,lab.float())
- # # print("9:{}".format(torch.cuda.memory_allocated(0)))
- # # loss_ncevalue = nce_loss.item()
- # sim_loss = similiarity(attn)
- # aux_loss1= aux_similiarity(fe[0][0],fe[3][0])
- # aux_loss2= aux_similiarity(fe[1][0],fe[4][0])
- # aux_loss3= aux_similiarity(fe[2][0],fe[5][0])
- # # print("10:{}".format(torch.cuda.memory_allocated(0)))
- # # # print(loss_cls ,nce_loss ,sim_loss ,aux_loss1,aux_loss2,aux_loss3)
- # # #cancer,logits,labels,features,attn
- # loss_value = sumloss(loss_cls ,nce_loss ,sim_loss ,aux_loss1,aux_loss2,aux_loss3)
- # # loss_value = loss_clsvalue
- # # print(loss_value)
- # # print("11:{}".format(torch.cuda.memory_allocated(0)))
-
- # loss_cls1 += loss_cls.item()
- # nce_loss1+= nce_loss.item()
- # sim_loss1+= sim_loss.item()
- # aux_loss11+= aux_loss1.item()
- # aux_loss12+= aux_loss2.item()
- # aux_loss13+= aux_loss3.item()
-
- # # bce_loss = 0.5 * losses_dict["BCELoss"](y_preds, masks)
- # # tverskly_loss = 0.5 * losses_dict["TverskyLoss"](y_preds, masks)
- # # losses = bce_loss + tverskly_loss
-
- # scaler.scale(losses).backward()
- # scaler.step(optimizer)
- # scaler.update()
- # gc.collect()
- # # losses_all += losses.item() / images.shape[0]
- # # bce_all += bce_loss.item() / images.shape[0]
- # # tverskly_all += tverskly_loss.item() / images.shape[0]
-
- # current_lr = optimizer.param_groups[0]['lr']
- # print("lr: {:.4f}".format(current_lr), flush=True)
- # print("cls: {:.3f}, ncd: {:.3f}, sim: {:.3f},sim1: {:.3f}, sim2: {:.3f}, sim3: {:.3f}"
- # .format(loss_cls1, nce_loss1, sim_loss1,aux_loss11,aux_loss12,aux_loss13), flush=True)
- # gc.collect()
-
-
- # In[19]:
-
-
@torch.no_grad()
def evaluate(data_loader, model, device):
    """Run validation: summed BCE-with-logits loss over the model's main and
    three auxiliary cancer heads, plus top-1 accuracy of the main head.

    Fixes over the previous version: the loop body referenced undefined names
    (``samples``/``targets``/``class_criterion``/``output``/``loss``); it now
    uses the actual loop variables and the criterion defined here, unpacks
    timm's ``accuracy`` correctly (``topk=(1,)`` returns a one-element list),
    and runs under no_grad since gradients are not needed at eval time.

    Args:
        data_loader: iterable of (images, target) batches.
        model: network returning (cancer, can1, can2, can3).
        device: device to move batches onto.

    Returns:
        dict mapping metric name -> global average.
    """
    criterion = torch.nn.BCEWithLogitsLoss()

    metric_logger = MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            cancer, can1, can2, can3 = model(images)
            # main head plus the three auxiliary heads
            loss = criterion(cancer, target.float())
            loss = loss + criterion(can1, target.float())
            loss = loss + criterion(can2, target.float())
            loss = loss + criterion(can3, target.float())

        acc1 = accuracy(cancer, target, topk=(1,))[0]

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
-
-
# ## Throughput measurement
-
- # In[20]:
-
-
@torch.no_grad()
def throughput(data_loader, model, logger):
    """Measure forward-pass throughput (images/sec) on a single batch.

    Warms up with 50 forward passes, then times 30 more and logs
    batch_size * 30 / elapsed. Only the first batch of the loader is used.
    """
    model.eval()

    for idx, (images, _) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        batch_size = images.shape[0]
        # warm-up passes so kernel autotuning/caching doesn't skew the timing
        for _ in range(50):
            model(images)
        torch.cuda.synchronize()
        logger.info(f"throughput averaged with 30 times")
        start = time.time()
        for _ in range(30):
            model(images)
        torch.cuda.synchronize()
        elapsed = time.time() - start
        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / elapsed}")
        return
-
-
- # # build
-
- # ## embeding patch
-
- # In[21]:
-
-
NORM_EPS = 1e-5


class PatchEmbed(nn.Module):
    """Patch embedding: average-pool, then a 1x1-conv projection with BatchNorm.

    The pooling window is selected by (stride, mode); modes "V" and "H" use
    different hard-coded rectangular windows. When no (stride, mode) case
    matches and the channel counts already agree, the layer is an identity.
    """

    def __init__(self, in_channels, out_channels, stride=1, mode="V"):
        super(PatchEmbed, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        # (stride, mode) -> pooling window; mirrors the original branch ladder
        if stride == 4 and mode == "V":
            pool = nn.AvgPool2d((4, 32), stride=4, ceil_mode=True, count_include_pad=False)
        elif stride == 2 and mode == "V":
            pool = nn.AvgPool2d((2, 16), stride=2, ceil_mode=True, count_include_pad=False)
        elif stride == 2 and mode == "H":
            pool = nn.AvgPool2d((64, 2), stride=2, ceil_mode=True, count_include_pad=False)
        elif stride == 1 and mode == "H":
            pool = nn.AvgPool2d((32, 1), stride=1, ceil_mode=True, count_include_pad=False)
        elif in_channels != out_channels:
            pool = nn.Identity()
        else:
            # same channel count and no pooling case matched: pure pass-through
            self.avgpool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()
            return
        self.avgpool = pool
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.norm = norm_layer(out_channels)

    def forward(self, x):
        # pool -> project -> normalize
        return self.norm(self.conv(self.avgpool(x)))
-
-
- # ## product QKV
-
- # In[22]:
-
-
class E_PQKV(nn.Module):
    """
    Project an input sequence into multi-head Q, K, V tensors.

    Q and V come out as (B, heads, N, head_dim); K is permuted to
    (B, heads, head_dim, N) so attention is simply ``q @ k``.
    """

    def __init__(self, dim, head_dim=32, out_dim=None, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)

    def forward(self, x):
        B, N, C = x.shape
        per_head = int(C // self.num_heads)
        q = self.q(x).reshape(B, N, self.num_heads, per_head).permute(0, 2, 1, 3)
        k = self.k(x).reshape(B, -1, self.num_heads, per_head).permute(0, 2, 3, 1)
        v = self.v(x).reshape(B, -1, self.num_heads, per_head).permute(0, 2, 1, 3)
        return q, k, v
-
-
-
- # ## Efficient Self attention
-
- # In[23]:
-
-
class E_MHA(nn.Module):
    """
    Efficient multi-head attention core: softmax((q @ k) * scale) @ v,
    with heads merged back into a (b, n, c) sequence.
    """

    def __init__(self, head_dim=32, qk_scale=None, attn_drop=0):
        super().__init__()
        self.scale = qk_scale or head_dim ** -0.5
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, q, k, v, b, n, c):
        # k arrives pre-transposed as (b, heads, head_dim, n)
        scores = torch.matmul(q, k) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        out = torch.matmul(weights, v)
        return out.transpose(1, 2).reshape(b, n, c)
-
-
- # ## MLP and res
-
- # In[24]:
-
-
class Mlp(nn.Module):
    """Two-layer feed-forward block (Linear -> act -> drop -> Linear -> drop)
    with a residual connection added at the end, so the input and output
    feature sizes must match for the skip to work."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        residual = x
        out = self.drop(self.act(self.fc1(x)))
        out = self.drop(self.fc2(out))
        out += residual
        return out
-
-
# ## Efficient Multi-Head Double-View Attention
-
- # In[25]:
-
-
class E_MTDVA(nn.Module):
    """
    Efficient Multi-Head Double-View Attention.

    Fuses a horizontal token stream (cat of x_mh, x_ah) and a vertical one
    (cat of x_mv, x_av) through: self-attention, FFN, a second self-attention
    pass, cross-attention between the streams, and a final FFN; each stream
    is then mean-pooled over tokens and the two pooled vectors concatenated.

    x_mh, x_ah, x_mv, x_av -> forward
    """
    def __init__(self, dim, head_dim=32, out_dim=None, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super(E_MTDVA, self).__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = E_PQKV(dim, head_dim, out_dim, qkv_bias, qk_scale, attn_drop, proj_drop, sr_ratio)
        self.sa = E_MHA(head_dim, qk_scale, attn_drop)  # self-attention core
        self.ca = E_MHA(head_dim, qk_scale, attn_drop)  # cross-attention core
        self.proj = Mlp(self.dim, self.dim * 4, self.out_dim)  # FFN (adds its own residual)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        # NOTE(review): one LayerNorm instance is shared by every norm site in
        # forward() (both streams, pre-attention and pre-FFN) — confirm that
        # separate norm layers were not intended.
        self.norm = nn.LayerNorm(dim)
        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            # token-reduction path; note this REPLACES self.norm with BatchNorm1d
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merged = False

    def merge_bn(self, pre_bn):
        # NOTE(review): references self.q/self.k/self.v and merge_pre_bn(),
        # none of which exist here (the projections live inside self.qkv) —
        # this method would raise if called and appears to be dead code
        # copied from another attention implementation.
        merge_pre_bn(self.q, pre_bn)
        if self.sr_ratio > 1:
            merge_pre_bn(self.k, pre_bn, self.norm)
            merge_pre_bn(self.v, pre_bn, self.norm)
        else:
            merge_pre_bn(self.k, pre_bn)
            merge_pre_bn(self.v, pre_bn)
        self.is_bn_merged = True
    # def pro_pkv(x_mh)

    def forward(self, x_mh, x_ah, x_mv, x_av):
        # build the two view streams by concatenating along the token axis
        h = torch.cat([x_mh, x_ah], dim=1)
        v = torch.cat([x_mv, x_av], dim=1)
        B, N, C = h.shape
        # self-attention on each stream; residual applied via in-place +=
        h = self.norm(h)
        v = self.norm(v)
        q_h, k_h, v_h = self.qkv(h)
        q_v, k_v, v_v = self.qkv(v)
        attn_h = self.sa(q_h, k_h, v_h, B, N, C)
        attn_v = self.sa(q_v, k_v, v_v, B, N, C)
        h += attn_h
        v += attn_v

        ##ffn (Mlp adds its own residual internally)
        h = self.norm(h)
        v = self.norm(v)
        h = self.proj(h)
        v = self.proj(v)
        # second self-attention pass
        h = self.norm(h)
        v = self.norm(v)
        q_h, k_h, v_h = self.qkv(h)
        q_v, k_v, v_v = self.qkv(v)
        attn_h = self.sa(q_h, k_h, v_h, B, N, C)
        attn_v = self.sa(q_v, k_v, v_v, B, N, C)
        h += attn_h
        v += attn_v
        # cross-attention between the streams
        h = self.norm(h)
        v = self.norm(v)
        q_h, k_h, v_h = self.qkv(h)
        q_v, k_v, v_v = self.qkv(v)
        # NOTE(review): both calls use (q_h, k_v, v_h), so attn_h and attn_v
        # are identical and v_v is never attended — confirm whether the second
        # call was meant to be self.ca(q_v, k_h, v_v, ...).
        attn_h = self.ca(q_h, k_v, v_h, B, N, C)
        attn_v = self.ca(q_h, k_v, v_h, B, N, C)
        h += attn_h
        v += attn_v
        ##ffn (final)
        h = self.norm(h)
        v = self.norm(v)
        h = self.proj(h)
        v = self.proj(v)
        # output: mean-pool each stream over tokens and concatenate the views
        attn_h = h.mean(1)
        attn_v = v.mean(1)
        x = torch.cat([attn_h, attn_v], dim=1)
        del h, v, q_h, k_h, v_h, q_v, k_v, v_v
        return x
-
-
- # In[26]:
-
-
class ShardExamine(nn.Module):
    """Embed two input views with horizontal and vertical patch embeddings,
    flatten each to a token sequence, fuse them with double-view attention
    (E_MTDVA), and return both the projected and the raw fused features."""

    def __init__(self, dim,
                 in_channels,
                 out_channels,
                 out_putdim,
                 num_head=64,
                 qkv_bias=False,
                 attn_drop=0.1,
                 proj_drop=0.1,
                 stride=1):
        super(ShardExamine, self).__init__()
        self.PEV = PatchEmbed(in_channels, out_channels, stride * 2, mode="V")
        self.PEH = PatchEmbed(in_channels, out_channels, stride, mode="H")
        self.proj = nn.Linear(out_channels * 2, out_putdim)
        self.attn = E_MTDVA(dim, num_head, None, True, qkv_bias, attn_drop, proj_drop)

    def forward(self, u_m, u_a):
        # vertical-view embeddings of both inputs
        v_m = self.PEV(u_m)
        v_a = self.PEV(u_a)
        # horizontal-view embeddings of both inputs
        h_m = self.PEH(u_m)
        h_a = self.PEH(u_a)

        # flatten (B, C, H, W) -> (B, H*W, C); all four sequences are shaped
        # with the vertical embedding's dimensions, as the original did
        B, C, H, W = v_m.shape
        tokens = H * W

        def to_sequence(t):
            return t.reshape(B, C, tokens).permute(0, 2, 1)

        fused = self.attn(to_sequence(h_m), to_sequence(h_a),
                          to_sequence(v_m), to_sequence(v_a))  # x_mh, x_ah, x_mv, x_av
        projected = self.proj(fused)

        return projected, fused
-
-
def get_logger(filename):
    """Create a logger that writes messages both to stdout and to
    ``<filename>.log``.

    Fix over the previous version: *filename* was ignored and the file
    handler always wrote to a hard-coded "(unknown).log".

    Args:
        filename: path prefix for the log file (".log" is appended).

    Returns:
        the configured logging.Logger.
    """
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger
-
# In[5]:
# Module-level logger created at import time.
# NOTE(review): get_logger opens its log file immediately, so this requires
# the target directory to exist on the machine running the script — confirm.
LOGGER = get_logger('/root/share/train')
-
-
- # In[28]:
-
-
class basNet(nn.Module):
    """
    Two-view cancer classifier ("only two viwe" in the original).

    Both views pass through a shared EfficientNet-B4 encoder; intermediate
    block features ('feat1') feed a double-view attention branch
    (ShardExamine) while the pooled features ('feat2') feed three auxiliary
    heads and, normalized and concatenated with the attention output, the
    main classification head.
    """
    def load_pretrain(self, ):
        # placeholder — pretrained weights are actually loaded in __init__
        return

    def __init__(self, K=12288, m=0.999, T=0.07, dim=128, nc=512):
        super(basNet, self).__init__()
        # self.output_type = ['inference', 'loss']

        # self.K = K
        # self.m = m
        # self.T = T
        # shared encoder for both views; its nc-dim classifier output becomes 'feat2'
        self.encoder_q = timm.create_model('tf_efficientnet_b4_ns',  # m 512 s 256 b1 320 b2 352
                                           pretrained=False,
                                           drop_rate=0.2,
                                           drop_path_rate=0.1,
                                           num_classes=nc
                                           )
        # NOTE(review): loads a checkpoint from the working directory at
        # construction time — confirm the path is valid wherever this runs.
        self.encoder_q.load_state_dict(torch.load('./tf_efficientnet_b4_ns.pth'))
        # expose intermediate features ('feat1') and classifier output ('feat2')
        self.encoder_q = torchvision.models._utils.IntermediateLayerGetter(self.encoder_q,
                                                                           {'blocks': 'feat1',
                                                                            # 'conv_head':'feat2',
                                                                            'classifier': 'feat2'})
        # NOTE(review): nothing in forward() below calls this 'fc' head — it
        # is attached but appears unused; confirm intent.
        self.encoder_q.fc = nn.Sequential(
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Linear(512, 128),
        )
        self.att = ShardExamine(128, 448, 128, 256)
        # output head
        self.mlp = nn.Sequential(
            nn.LayerNorm(1280),
            nn.Linear(1280, 512),
            nn.GELU(),
            nn.Linear(512, 128),
        )  # <todo> mlp needs to be deep if backbone is strong?
        self.aux = nn.Sequential(
            nn.LayerNorm(1024),
            nn.Linear(1024, 128),
            nn.GELU(),
            nn.Linear(128, 1),
        )
        self.cancer = nn.Linear(128, 1)
        self.auxcancer = nn.Linear(512, 1)
        # self.labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

    def forward(self, x):
        """
        Input:
            x: batch of images with the two views composited side by side
               along the last (width) axis; they are split apart below.
        Output:
            cancer: sigmoid probability from the fused main head
            can1, can2, can3: auxiliary head outputs (raw logits — see NOTE)
        """

        batch_size, C, H, W = x.shape
        x = x.reshape(-1, C, H, W)
        # split the composite into the two views
        # NOTE(review): the split size is H//2 but the split axis is the
        # width — this only lines up when W == 2 * (H // 2); confirm geometry.
        x_m = torch.split(x, H // 2, dim=-1)[0]

        # print(x_m.shape)
        x_a = torch.split(x, H // 2, dim=-1)[1]

        # compute query features
        q_mp = self.encoder_q(x_m)  # queries: NxC
        q_ap = self.encoder_q(x_a)  # queries: NxC
        can1 = self.auxcancer(q_mp['feat2'])
        # NOTE(review): the results of these bare torch.sigmoid(...) calls are
        # discarded (sigmoid is not in-place), so can1/can2/can3 are returned
        # as raw logits — confirm whether `can1 = torch.sigmoid(can1)` etc.
        # was intended.
        torch.sigmoid(can1)
        can2 = self.auxcancer(q_ap['feat2'])
        torch.sigmoid(can2)
        can3 = self.aux(torch.cat([q_ap['feat2'], q_mp['feat2']], -1))
        torch.sigmoid(can3)
        # NOTE(review): both arguments are q_mp['feat1'] — the second was
        # presumably meant to be q_ap['feat1']; confirm.
        attn, attn1 = self.att(q_mp['feat1'], q_mp['feat1'])
        q_m = q_mp['feat2']
        q_a = q_ap['feat2']
        q_m = nn.functional.normalize(q_m, dim=1)
        q_a = nn.functional.normalize(q_a, dim=1)

        last = torch.cat([q_a, q_m, attn], -1)
        last = self.mlp(last)
        cancer = self.cancer(last).reshape(-1)
        cancer = torch.sigmoid(cancer)
        # if cancer>0.5:
        return cancer, can1, can2, can3
        # return features
-
-
- # # loss
-
# ## Auxiliary loss function
-
- # In[29]:
-
-
def aux_similiarity(x1, x2):  # features[0/1/2][0], features[3/4/5][0]
    """Auxiliary cosine-similarity loss between two feature tensors.

    Computes per-position cosine similarity along the last axis (the 1e-6
    guards against zero norms), then reduces via Python ``sum`` over the
    flattened tensor's first dimension, halves, and averages the remainder.

    NOTE(review): the ``sum(...)``/``len(s)`` reduction is shape-dependent
    (it iterates the batch dimension of the flattened map) — confirm the
    intended input rank before relying on the exact scaling.
    """
    p12 = (x1 * x2).sum(-1)
    p1 = torch.sqrt((x1 * x1).sum(-1))
    p2 = torch.sqrt((x2 * x2).sum(-1))
    s = sum(nn.Flatten()(p12 / (p1 * p2 + 1e-6))) / 2
    s = s.sum() / len(s)
    return s
-
-
# # Similarity loss function
-
- # In[30]:
-
-
def similiarity(x):  # features[0/1/2][0], features[3/4/5][0]
    """Cosine similarity between two chunks of *x* split along the last axis.

    The chunk size is ``len(x) // 2`` — i.e. half the size of x's FIRST
    dimension, exactly as in the original — and only the first two chunks
    are compared. The 1e-6 term guards against zero norms.
    """
    half = len(x) // 2
    chunks = torch.split(x, half, dim=-1)
    a, b = chunks[0], chunks[1]
    dot = (a * b).sum(-1)
    norm_a = torch.sqrt((a * a).sum(-1))
    norm_b = torch.sqrt((b * b).sum(-1))
    return dot / (norm_a * norm_b + 1e-6)
-
-
-
class RASampler(torch.utils.data.Sampler):
    """Distributed sampler with repeated augmentation (3x).

    Every sample index is repeated three times, the repeated list is padded
    so it divides evenly across ranks, and each rank takes a strided slice —
    ensuring different augmented copies of the same image are seen by
    different processes (GPUs).
    Heavily based on torch.utils.data.DistributedSampler.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # 3x repetition, rounded up so every rank gets the same share
        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # actually yielded per epoch: dataset truncated to a multiple of 256, split per rank
        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            order = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            order = list(range(len(self.dataset)))

        # repeat each index 3 times, then pad to make it evenly divisible
        repeated = [idx for idx in order for _ in range(3)]
        repeated += repeated[:(self.total_size - len(repeated))]
        assert len(repeated) == self.total_size

        # strided subsample for this rank
        mine = repeated[self.rank:self.total_size:self.num_replicas]
        assert len(mine) == self.num_samples

        return iter(mine[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
-
-
- # In[34]:
-
-
- from torch.utils.data.distributed import DistributedSampler
class WeightedRandomSamplerDDP(DistributedSampler):
    r"""Weighted random sampling with replacement, sharded across DDP ranks.

    Each rank owns the weight slice ``weight[rank::num_replicas]`` and draws
    its share of samples via ``torch.multinomial``; drawn local positions are
    mapped back to global dataset indices.

    Fixes over the previous version (both raised NameError):
      * the positivity validation referenced ``num_samples`` before it was
        ever computed — the count is now computed first, then validated;
      * the weights tensor was built from an undefined name ``weights``
        instead of the ``weight`` argument.

    Args:
        data_set: Dataset used for sampling.
        weight (sequence): a sequence of weights, not necessary summing up to one.
        num_replicas (int): number of processes participating in distributed training.
        rank (int): rank of the current process within :attr:`num_replicas`.
        replacement (bool): if ``True``, samples are drawn with replacement.
        generator (Generator): generator used in sampling.
    """

    def __init__(self, data_set, weight, num_replicas: int, rank: int,
                 replacement: bool = True, generator=None) -> None:
        super(WeightedRandomSamplerDDP, self).__init__(data_set, num_replicas, rank)
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        # 3x oversampling of the dataset split evenly across ranks
        num_samples = int(math.ceil(len(self.dataset) * 3.0 / num_replicas))
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        self.replacement = replacement
        self.generator = generator
        self.num_replicas = num_replicas
        self.rank = rank
        # this rank only sees every num_replicas-th weight, offset by its rank
        self.weights = torch.as_tensor(weight, dtype=torch.double)[self.rank::self.num_replicas]
        self.num_samples = num_samples // self.num_replicas

    def __iter__(self):
        # draw local positions, then map back to global dataset indices
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        rand_tensor = self.rank + rand_tensor * self.num_replicas
        return iter(rand_tensor.tolist())

    def __len__(self):
        return self.num_samples
-
-
- # # train_main
-
- # In[35]:
-
-
def get_logger(filename):
    """Build an INFO-level logger writing to stdout and to ``<filename>.log``.

    Args:
        filename: path prefix of the log file (``.log`` is appended).

    Returns:
        logging.Logger: the configured module-level logger.
    """
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # Fix: the ``filename`` argument was ignored (the log path was hard-coded)
    # and repeated calls stacked duplicate handlers, duplicating every record.
    if not logger.handlers:
        handler1 = StreamHandler()
        handler1.setFormatter(Formatter("%(message)s"))
        handler2 = FileHandler(filename=f"{filename}.log")
        handler2.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler1)
        logger.addHandler(handler2)
    return logger
-
# In[5]:
# Module-level logger shared by the training loop in main().
LOGGER = get_logger('./train')
-
-
- # In[36]:
-
-
def main(args):
    """Train the model described by ``args`` and evaluate it after every epoch.

    Builds the weighted-sampling train loader and sequential val loader, the
    model, optimizer and per-step cosine LR schedule, then runs the AMP
    training loop, checkpointing the last and best models under
    ``args.output_dir``.

    Args:
        args: parsed namespace from ``get_args_parser()``.
    """
    print(args)
    device = torch.device(args.device)
    # Seed per-rank so distributed workers draw different augmentations.
    seed = args.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    print('dataset_train')
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)
    # Per-sample weights for class balancing; schema comes from build_weight.
    sample_weights = build_weight(args.data_path)
    sampler_train = torch.utils.data.WeightedRandomSampler(sample_weights, len(dataset_train))
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=250,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    print(f"Creating model: {args.model}")
    model = basNet()
    model.to(device)
    model_ema = None
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    # Linear LR scaling rule: scale lr with the global batch size.
    linear_scaled_lr = args.lr * args.batch_size * get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()  # AMP gradient scaler (fix: was created twice)

    # Per-iteration cosine schedule.  warmup_t / t_initial are expressed in
    # optimizer steps, so the scheduler must run with t_in_epochs=False and be
    # advanced via step_update() every batch (fix: it was stepped once per
    # epoch, which would have kept the LR inside the multi-epoch warmup).
    from timm.scheduler import CosineLRScheduler
    nbatch = len(data_loader_train)
    warmup = args.warmup_epochs * nbatch
    nsteps = args.epochs * nbatch
    lr_scheduler = CosineLRScheduler(optimizer,
                                     warmup_t=warmup, warmup_lr_init=0.0, warmup_prefix=True,
                                     t_initial=(nsteps - warmup), lr_min=args.warmup_lr,
                                     t_in_epochs=False)

    class_criterion = nn.BCEWithLogitsLoss()
    # NOTE(review): nce_criterion is never used in this function -- kept for
    # parity with the original; confirm whether it can be removed.
    nce_criterion = nn.CrossEntropyLoss()

    if not args.output_dir:
        args.output_dir = args.model
    import os
    if not os.path.exists(args.model):
        os.mkdir(args.model)
    output_dir = Path(args.output_dir)

    if args.eval:
        # Fix: use model_without_ddp so eval also works without DDP wrapping
        # (model.module does not exist on a plain nn.Module).
        if hasattr(model_without_ddp, "merge_bn"):
            print("Merge pre bn to speedup inference.")
            model_without_ddp.merge_bn()
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    tb = time.time()

    for epoch in range(args.start_epoch, args.epochs):
        LOGGER.info("epoch_start_time" + str(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        optimizer.zero_grad()
        model.train()
        # Running sums of the four loss terms as plain floats (fix: the
        # original accumulated tensors, keeping every batch's autograd graph
        # alive for the whole epoch).
        cls1 = cls2 = cls3 = cls4 = 0.0
        for k, batch in enumerate(data_loader_train):
            samples = batch[0].to(device, non_blocking=True)
            targets = batch[1].to(device, non_blocking=True)
            with torch.cuda.amp.autocast():
                cancer, can1, can2, can3 = model(samples)
                loss_cls = class_criterion(cancer, targets.float())
                loss_cls1 = class_criterion(can1.squeeze(-1), targets.float())
                loss_cls2 = class_criterion(can2.squeeze(-1), targets.float())
                loss_cls3 = class_criterion(can3.squeeze(-1), targets.float())
            cls1 += loss_cls.item()
            cls2 += loss_cls1.item()
            cls3 += loss_cls2.item()
            cls4 += loss_cls3.item()
            # Deep supervision: auxiliary heads at full weight, main head at 0.8.
            loss = loss_cls3 + loss_cls1 + loss_cls2 + 0.8 * loss_cls
            if k % 50 == 0:
                tt = (time.time() - tb) / 60
                LOGGER.info('ibatch %d %.4f %.2f min' % (k + 1, cls1 + cls2 + cls3 + cls4, tt))
                LOGGER.info('loss %.4f %.4f %.4f %.4f' % (cls1, cls2, cls3, cls4))
            if not math.isfinite(loss.item()):
                print("Loss is {}, stopping training".format(loss.item()))
                sys.exit(1)
            optimizer.zero_grad()
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            # NativeScaler does the backward pass, unscaling, clipping and step.
            loss_scaler(loss, optimizer, clip_grad=args.clip_grad,
                        parameters=model.parameters(), create_graph=is_second_order)
            torch.cuda.synchronize()
            if model_ema is not None:
                model_ema.update(model)
            # Advance the per-step cosine schedule.
            lr_scheduler.step_update(epoch * nbatch + k + 1)
        gc.collect()  # once per epoch (fix: was run every iteration)

        # Fix: LOGGER.info was called with raw tensors as positional args
        # (fails at emit time, and logged loss_cls1 twice); log the
        # accumulated epoch sums with lazy %-style args instead.
        LOGGER.info('epoch %d losses %.4f %.4f %.4f %.4f', epoch, cls1, cls2, cls3, cls4)
        # Fix: train_stats was referenced below but never defined (NameError).
        train_stats = {'loss_main': cls1, 'loss_aux1': cls2,
                       'loss_aux2': cls3, 'loss_aux3': cls4}

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        if test_stats["acc1"] > max_accuracy:
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint_best.pth']
                for checkpoint_path in checkpoint_paths:
                    save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch}

        if args.output_dir:  # and is_main_process()
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
-
-
- # In[38]:
-
-
def get_args_parser():
    """Build the argparse parser holding every training/evaluation option.

    Returns an ``argparse.ArgumentParser`` created with ``add_help=False`` so
    it can be composed as a parent parser.  All option names, dests and
    defaults are unchanged; only help texts that contradicted the actual
    ``default=`` values (or were copy-pasted from other options) are fixed.
    """
    parser = argparse.ArgumentParser('Next-ViT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=12, type=int)
    parser.add_argument('--epochs', default=30, type=int)

    # Model parameters
    parser.add_argument('--model', default='pvt_small', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=1024, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.1, metavar='PCT',
                        help='Dropout rate (default: 0.1)')
    parser.add_argument('--drop-path', type=float, default=0.0, metavar='PCT',
                        help='Drop path rate (default: 0.0)')
    # Fix: help said "Drop path rate", copy-pasted from --drop-path.
    parser.add_argument('--flops', type=float, default=0.1, metavar='PCT',
                        help='FLOPs factor (default: 0.1)')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=5, metavar='NORM',
                        help='Clip gradient norm (default: 5)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='sched', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "sched")')
    parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
                        help='learning rate (default: 5e-3)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=40, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup_epochs', type=int, default=3, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    # Fix: help string contained leftover quote/concatenation artifacts and a
    # stray trailing comma turned the statement into a 1-tuple.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=False)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')
    # Fix: help text was copy-pasted from --resplit.
    parser.add_argument('--distributed', action='store_true', default=False,
                        help='Enable distributed (DDP) training')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # * Finetuning params
    parser.add_argument('--finetune', action='store_true', help='Perform finetune.')

    # Dataset parameters
    parser.add_argument('--data-path', default='./data', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='Image Net dataset path')
    parser.add_argument('--use-mcloader', action='store_true', default=False, help='Use mcloader')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output-dir', default='../outputdir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='./nextvit_small_in1k6m_384.pth', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # test throughput
    parser.add_argument('--throughout', action='store_true', help='Perform throughout only')
    return parser
-
-
- # In[39]:
-
-
# Top-level CLI parser; get_args_parser() is attached as a parent so its
# options are inherited while -h/--help lives on this wrapper.
parser = argparse.ArgumentParser('Twins training and evaluation script', parents=[get_args_parser()])


# In[40]:


# NOTE(review): parse_args([]) ignores the real command line and always uses
# the declared defaults -- presumably intentional for notebook execution;
# confirm before running this file as a script.
args = parser.parse_args([])


# In[41]:


# Ensure the checkpoint/log directory exists, then start training.
if args.output_dir:
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
|