|
- # Copyright (c) OpenMMLab. All rights reserved.
- import os
-
- from mmengine import Config
- from mmengine.logging import MessageHub, MMLogger
- from mmengine.registry import DefaultScope
- from mmengine.testing import RunnerTestCase
- from mmengine.visualization import Visualizer
-
-
- class TestRunnerTestCase(RunnerTestCase):
-
- def test_setup(self):
- self.assertIsInstance(self.epoch_based_cfg, Config)
- self.assertIsInstance(self.iter_based_cfg, Config)
- self.assertIn('MASTER_ADDR', self.dist_cfg)
- self.assertIn('MASTER_PORT', self.dist_cfg)
- self.assertIn('RANK', self.dist_cfg)
- self.assertIn('WORLD_SIZE', self.dist_cfg)
- self.assertIn('LOCAL_RANK', self.dist_cfg)
-
- def test_tearDown(self):
- self.tearDown()
- self.assertEqual(MMLogger._instance_dict, {})
- self.assertEqual(MessageHub._instance_dict, {})
- self.assertEqual(Visualizer._instance_dict, {})
- self.assertEqual(DefaultScope._instance_dict, {})
- # tearDown should not be called twice.
- self.tearDown = super(RunnerTestCase, self).tearDown
-
- def test_build_runner(self):
- runner = self.build_runner(self.epoch_based_cfg)
- runner.train()
- runner.val()
- runner.test()
-
- runner = self.build_runner(self.iter_based_cfg)
- runner.train()
- runner.val()
- runner.test()
-
- def test_experiment_name(self):
- runner1 = self.build_runner(self.epoch_based_cfg)
- runner2 = self.build_runner(self.epoch_based_cfg)
- self.assertNotEqual(runner1.experiment_name, runner2.experiment_name)
-
- def test_init_dist(self):
- self.setup_dist_env()
- self.assertEqual(
- str(self.dist_cfg['MASTER_PORT']), os.environ['MASTER_PORT'])
- self.assertEqual(self.dist_cfg['MASTER_ADDR'],
- os.environ['MASTER_ADDR'])
- self.assertEqual(self.dist_cfg['RANK'], os.environ['RANK'])
- self.assertEqual(self.dist_cfg['LOCAL_RANK'], os.environ['LOCAL_RANK'])
- self.assertEqual(self.dist_cfg['WORLD_SIZE'], os.environ['WORLD_SIZE'])
- fisrt_port = os.environ['MASTER_ADDR']
- self.setup_dist_env()
- self.assertNotEqual(fisrt_port, os.environ['MASTER_PORT'])
|