|
- # Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os

import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner
-
-
class MMResNet50(BaseModel):
    """torchvision ResNet-50 adapted to MMEngine's ``BaseModel`` interface.

    ``forward`` dispatches on ``mode``:
    - ``'loss'`` returns a loss dict for training.
    - ``'predict'`` returns ``(logits, labels)`` for evaluation.
    - ``'tensor'`` returns the raw logits.
    """

    def __init__(self):
        super().__init__()
        # Built with torchvision's default constructor arguments; no
        # pretrained weights are requested here.
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        """Run the backbone and package its output for the requested mode.

        Args:
            imgs: Image batch passed straight to the ResNet backbone.
            labels: Ground-truth class indices matching ``imgs``.
            mode: ``'loss'``, ``'predict'`` or ``'tensor'``.

        Returns:
            ``{'loss': cross_entropy(logits, labels)}`` for ``'loss'``,
            ``(logits, labels)`` for ``'predict'``, and ``logits`` for
            ``'tensor'``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        # Validate first so an invalid mode fails loudly and cheaply instead
        # of silently returning None after a full forward pass.
        if mode not in ('loss', 'predict', 'tensor'):
            raise ValueError(
                f"Invalid mode {mode!r}; expected 'loss', 'predict' or "
                "'tensor'.")
        logits = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(logits, labels)}
        if mode == 'predict':
            return logits, labels
        return logits
-
-
class Accuracy(BaseMetric):
    """Top-1 classification accuracy, reported as a percentage."""

    def process(self, data_batch, data_samples):
        """Record how many predictions in one batch are correct.

        Args:
            data_batch: The batch that was fed to the model (unused).
            data_samples: A ``(score, gt)`` pair. It is expected to be the
                ``(logits, labels)`` tuple returned by
                ``MMResNet50.forward`` in ``'predict'`` mode. ``score`` is
                reduced with ``argmax(dim=1)``.
        """
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            # Store a plain Python int so the accumulated results (and the
            # final metric) are ordinary numbers rather than tensors.
            'correct': (score.argmax(dim=1) == gt).sum().item(),
        })

    def compute_metrics(self, results):
        """Aggregate per-batch counts into an overall accuracy.

        Args:
            results: The dicts appended by ``process``, each with
                ``'batch_size'`` and ``'correct'`` entries.

        Returns:
            dict: ``{'accuracy': value}``, where ``value`` is the percentage
            (0-100) of correctly classified samples. It is ``0.0`` when no
            samples were processed, instead of dividing by zero.
        """
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        if total_size == 0:
            return dict(accuracy=0.0)
        return dict(accuracy=100 * total_correct / total_size)
-
-
def parse_args():
    """Parse the command-line options of the distributed training script.

    Returns:
        argparse.Namespace: It has two attributes:
        - ``launcher``: one of ``'none'``, ``'pytorch'``, ``'slurm'`` or
          ``'mpi'``; defaults to ``'none'``.
        - ``local_rank``: an int; defaults to 0.

    Side effects:
        Sets ``os.environ['LOCAL_RANK']`` to ``str(args.local_rank)`` when
        that variable is not already defined.
    """
    parser = argparse.ArgumentParser(description='Distributed Training')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # PyTorch >= 2.0's torch.distributed.launch passes ``--local-rank``,
    # while older versions pass ``--local_rank``. Accept both spellings;
    # both store into ``args.local_rank``.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

    args = parser.parse_args()
    # Some launchers pass the rank only on the command line. Export it so
    # that environment-based process-group setup can find it. A value that
    # is already present (e.g. set by torchrun) is left untouched.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
-
-
def main():
    """Entry point: train ResNet-50 on CIFAR-10 with an MMEngine ``Runner``.

    The launcher option comes from ``parse_args``. Both the training and the
    test split of CIFAR-10 are loaded from ``data/cifar10``
    (``download=True``). The runner then trains ``MMResNet50`` for two epochs
    with SGD and evaluates ``Accuracy`` on the test split after every epoch.
    """
    args = parse_args()

    data_root = 'data/cifar10'
    # Per-channel mean/std shared by the training and validation pipelines.
    norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])

    # The training pipeline adds light augmentation: a padded random crop
    # and a horizontal flip.
    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**norm_cfg),
    ])
    train_set = torchvision.datasets.CIFAR10(
        data_root, train=True, download=True, transform=train_pipeline)

    val_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(**norm_cfg),
    ])
    valid_set = torchvision.datasets.CIFAR10(
        data_root, train=False, download=True, transform=val_pipeline)

    def make_loader_cfg(dataset, shuffle):
        # Both splits share the batch size and collate function; only the
        # sampler's shuffle flag differs.
        return dict(
            batch_size=32,
            dataset=dataset,
            sampler=dict(type='DefaultSampler', shuffle=shuffle),
            collate_fn=dict(type='default_collate'),
        )

    runner = Runner(
        model=MMResNet50(),
        work_dir='./work_dirs',
        train_dataloader=make_loader_cfg(train_set, shuffle=True),
        optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
        train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
        val_dataloader=make_loader_cfg(valid_set, shuffle=False),
        val_cfg=dict(),
        val_evaluator=dict(type=Accuracy),
        launcher=args.launcher,
    )
    runner.train()
-
-
if __name__ == '__main__':
    # Start training only when this file is executed as a script, not when
    # it is imported.
    main()
|