|
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- import logging
- import os
- import shutil
- import tempfile
- import time
- from unittest import TestCase
- from uuid import uuid4
-
- import torch
- import torch.nn as nn
- from torch.distributed import destroy_process_group
- from torch.utils.data import Dataset
-
- import mmengine.hooks # noqa F401
- import mmengine.optim # noqa F401
- from mmengine.config import Config
- from mmengine.dist import is_distributed
- from mmengine.evaluator import BaseMetric
- from mmengine.logging import MessageHub, MMLogger
- from mmengine.model import BaseModel
- from mmengine.registry import DATASETS, METRICS, MODELS, DefaultScope
- from mmengine.runner import Runner
- from mmengine.visualization import Visualizer
-
-
- class ToyModel(BaseModel):
-
- def __init__(self, data_preprocessor=None):
- super().__init__(data_preprocessor=data_preprocessor)
- self.linear1 = nn.Linear(2, 2)
- self.linear2 = nn.Linear(2, 1)
-
- def forward(self, inputs, data_samples=None, mode='tensor'):
- if isinstance(inputs, list):
- inputs = torch.stack(inputs)
- if isinstance(data_samples, list):
- data_samples = torch.stack(data_samples)
- outputs = self.linear1(inputs)
- outputs = self.linear2(outputs)
-
- if mode == 'tensor':
- return outputs
- elif mode == 'loss':
- loss = (data_samples - outputs).sum()
- outputs = dict(loss=loss)
- return outputs
- elif mode == 'predict':
- return outputs
-
-
- class ToyDataset(Dataset):
- METAINFO = dict() # type: ignore
- data = torch.randn(12, 2)
- label = torch.ones(12)
-
- @property
- def metainfo(self):
- return self.METAINFO
-
- def __len__(self):
- return self.data.size(0)
-
- def __getitem__(self, index):
- return dict(inputs=self.data[index], data_samples=self.label[index])
-
-
- class ToyMetric(BaseMetric):
-
- def __init__(self, collect_device='cpu', dummy_metrics=None):
- super().__init__(collect_device=collect_device)
- self.dummy_metrics = dummy_metrics
-
- def process(self, data_batch, predictions):
- result = {'acc': 1}
- self.results.append(result)
-
- def compute_metrics(self, results):
- return dict(acc=1)
-
-
- class RunnerTestCase(TestCase):
- """A test case to build runner easily.
-
- `RunnerTestCase` will do the following things:
-
- 1. Registers a toy model, a toy metric, and a toy dataset, which can be
- used to run the `Runner` successfully.
- 2. Provides epoch based and iteration based cfg to build runner.
- 3. Provides `build_runner` method to build runner easily.
- 4. Clean the global variable used by the runner.
- """
- dist_cfg = dict(
- MASTER_ADDR='127.0.0.1',
- MASTER_PORT=29600,
- RANK='0',
- WORLD_SIZE='1',
- LOCAL_RANK='0')
-
- def setUp(self) -> None:
- self.temp_dir = tempfile.TemporaryDirectory()
- # Prevent from registering module with the same name by other unit
- # test. These registries will be cleared in `tearDown`
- MODELS.register_module(module=ToyModel, force=True)
- METRICS.register_module(module=ToyMetric, force=True)
- DATASETS.register_module(module=ToyDataset, force=True)
- epoch_based_cfg = dict(
- work_dir=self.temp_dir.name,
- model=dict(type='ToyModel'),
- train_dataloader=dict(
- dataset=dict(type='ToyDataset'),
- sampler=dict(type='DefaultSampler', shuffle=True),
- batch_size=3,
- num_workers=0),
- val_dataloader=dict(
- dataset=dict(type='ToyDataset'),
- sampler=dict(type='DefaultSampler', shuffle=False),
- batch_size=3,
- num_workers=0),
- val_evaluator=[dict(type='ToyMetric')],
- test_dataloader=dict(
- dataset=dict(type='ToyDataset'),
- sampler=dict(type='DefaultSampler', shuffle=False),
- batch_size=3,
- num_workers=0),
- test_evaluator=[dict(type='ToyMetric')],
- optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
- train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
- val_cfg=dict(),
- test_cfg=dict(),
- default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
- custom_hooks=[],
- env_cfg=dict(dist_cfg=dict(backend='nccl')),
- experiment_name='test1')
- self.epoch_based_cfg = Config(epoch_based_cfg)
-
- # prepare iter based cfg.
- self.iter_based_cfg: Config = copy.deepcopy(self.epoch_based_cfg)
- self.iter_based_cfg.train_dataloader = dict(
- dataset=dict(type='ToyDataset'),
- sampler=dict(type='InfiniteSampler', shuffle=True),
- batch_size=3,
- num_workers=0)
- self.iter_based_cfg.log_processor = dict(by_epoch=False)
-
- self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
- self.iter_based_cfg.default_hooks = dict(
- logger=dict(type='LoggerHook', interval=1),
- checkpoint=dict(
- type='CheckpointHook', interval=12, by_epoch=False))
-
- def tearDown(self):
- # `FileHandler` should be closed in Windows, otherwise we cannot
- # delete the temporary directory
- logging.shutdown()
- MMLogger._instance_dict.clear()
- Visualizer._instance_dict.clear()
- DefaultScope._instance_dict.clear()
- MessageHub._instance_dict.clear()
- MODELS.module_dict.pop('ToyModel', None)
- METRICS.module_dict.pop('ToyMetric', None)
- DATASETS.module_dict.pop('ToyDataset', None)
- self.temp_dir.cleanup()
- if is_distributed():
- destroy_process_group()
-
- def build_runner(self, cfg: Config):
- cfg.experiment_name = self.experiment_name
- runner = Runner.from_cfg(cfg)
- return runner
-
- @property
- def experiment_name(self):
- # Since runners could be built too fast to have a unique experiment
- # name(timestamp is the same), here we use uuid to make sure each
- # runner has the unique experiment name.
- return f'{self._testMethodName}_{time.time()} + ' \
- f'{uuid4()}'
-
- def setup_dist_env(self):
- self.dist_cfg['MASTER_PORT'] += 1
- os.environ['MASTER_PORT'] = str(self.dist_cfg['MASTER_PORT'])
- os.environ['MASTER_ADDR'] = self.dist_cfg['MASTER_ADDR']
- os.environ['RANK'] = self.dist_cfg['RANK']
- os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
- os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']
-
- def clear_work_dir(self):
- logging.shutdown()
- for filename in os.listdir(self.temp_dir.name):
- filepath = os.path.join(self.temp_dir.name, filename)
- if os.path.isfile(filepath):
- os.remove(filepath)
- else:
- shutil.rmtree(filepath)
|